Commit e8afe91a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1e731018
...@@ -578,7 +578,7 @@ struct cpu_logsoftmax ...@@ -578,7 +578,7 @@ struct cpu_logsoftmax
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
batch_lens[op.axis] = 1; batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
......
...@@ -17,10 +17,10 @@ argument logsoftmax(hipStream_t stream, ...@@ -17,10 +17,10 @@ argument logsoftmax(hipStream_t stream,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = output_shape.lens();
auto num_in_batch = lens[axis]; auto num_in_batch = lens[axis];
auto batch_lens = lens; auto batch_lens = lens;
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens}; migraphx::shape batch_shape{output_shape.type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) { visit_all(args.back(), args.front())([&](auto output, auto input) {
...@@ -33,21 +33,21 @@ argument logsoftmax(hipStream_t stream, ...@@ -33,21 +33,21 @@ argument logsoftmax(hipStream_t stream,
// each thread is for one item in the batch // each thread is for one item in the batch
gs_launch(stream, batch_shape.elements())([=](auto i) { gs_launch(stream, batch_shape.elements())([=](auto i) {
auto batch_idx = desc_batch.multi(i); auto batch_idx = desc_batch.multi(i);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// get max // get max
auto batch_max = input_ptr[desc_data.linear(batch_idx)]; auto batch_max = input_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < num_in_batch; ++j) for(std::size_t j = 1; j < num_in_batch; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
size_t idx = desc_data.linear(data_idx); size_t idx = desc_data.linear(data_idx);
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[idx])); batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[idx]));
} }
for(std::size_t j = 0; j < num_in_batch; ++j) for(std::size_t j = 0; j < num_in_batch; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
size_t idx = desc_data.linear(data_idx); size_t idx = desc_data.linear(data_idx);
output_ptr[idx] = input_ptr[idx] - batch_max; output_ptr[idx] = input_ptr[idx] - batch_max;
} }
...@@ -55,7 +55,7 @@ argument logsoftmax(hipStream_t stream, ...@@ -55,7 +55,7 @@ argument logsoftmax(hipStream_t stream,
for(std::size_t j = 1; j < num_in_batch; ++j) for(std::size_t j = 1; j < num_in_batch; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
size_t idx = desc_data.linear(data_idx); size_t idx = desc_data.linear(data_idx);
batch_sum += ::exp(to_hip_type(output_ptr[idx])); batch_sum += ::exp(to_hip_type(output_ptr[idx]));
} }
batch_sum = ::log(to_hip_type(batch_sum)); batch_sum = ::log(to_hip_type(batch_sum));
...@@ -63,7 +63,7 @@ argument logsoftmax(hipStream_t stream, ...@@ -63,7 +63,7 @@ argument logsoftmax(hipStream_t stream,
for(std::size_t j = 0; j < num_in_batch; ++j) for(std::size_t j = 0; j < num_in_batch; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
size_t idx = desc_data.linear(data_idx); size_t idx = desc_data.linear(data_idx);
output_ptr[idx] -= batch_sum; output_ptr[idx] -= batch_sum;
} }
}); });
......
...@@ -1002,14 +1002,12 @@ TEST_CASE(logsoftmax_test_axis_0) ...@@ -1002,14 +1002,12 @@ TEST_CASE(logsoftmax_test_axis_0)
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = { std::vector<float> s = {
-0.135261, -2.843968, -0.659995, -0.488413, -1.051857, -2.812936, -0.135261, -2.843968, -0.659995, -0.488413, -1.051857, -2.812936, -0.250956, -0.353985,
-0.250956, -0.353985, -1.155980, -0.603651, -0.211969, -0.175371, -1.155980, -0.603651, -0.211969, -0.175371, -1.336552, -3.885010, -1.871544, -0.837083,
-1.336552, -3.885010, -1.871544, -0.837083, -0.887745, -0.433338, -0.887745, -0.433338, -1.158864, -4.911197, -1.147972, -0.666711, -0.996874, -0.981418,
-1.158864, -4.911197, -1.147972, -0.666711, -0.996874, -0.981418, -0.851145, -0.853988, -0.858112, -2.067420, -0.059956, -0.727436, -0.950881, -0.429689,
-0.851145, -0.853988, -0.858112, -2.067420, -0.059956, -0.727436, -0.061906, -1.505332, -1.210277, -0.377970, -0.791448, -1.655428, -1.827253, -0.304828,
-0.950881, -0.429689, -0.061906, -1.505332, -1.210277, -0.377970, -0.020762, -0.167101, -0.567346, -0.530319, -1.045094, -0.376648, -0.007391, -0.381670,
-0.791448, -1.655428, -1.827253, -0.304828, -0.020762, -0.167101,
-0.567346, -0.530319, -1.045094, -0.376648, -0.007391, -0.381670,
-0.720302, -0.460499, -0.469651, -0.556740, -0.554628, -0.551582}; -0.720302, -0.460499, -0.469651, -0.556740, -0.554628, -0.551582};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
...@@ -1037,14 +1035,12 @@ TEST_CASE(logsoftmax_test_axis_1) ...@@ -1037,14 +1035,12 @@ TEST_CASE(logsoftmax_test_axis_1)
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = { std::vector<float> s = {
-0.550468, -2.132973, -1.549746, -0.650533, -1.051529, -2.248570, -0.550468, -2.132973, -1.549746, -0.650533, -1.051529, -2.248570, -0.141017, -2.028357,
-0.141017, -2.028357, -1.947730, -1.511324, -0.166597, -0.379726, -1.947730, -1.511324, -0.166597, -0.379726, -1.965689, -1.172109, -1.475721, -2.700831,
-1.965689, -1.172109, -1.475721, -2.700831, -1.537011, -0.658754, -1.537011, -0.658754, -1.596017, -3.353137, -2.266743, -1.084197, -1.076214, -0.406712,
-1.596017, -3.353137, -2.266743, -1.084197, -1.076214, -0.406712, -2.743019, -0.425526, -1.079083, -2.139486, -1.270584, -1.024088, -1.154231, -3.201762,
-2.743019, -0.425526, -1.079083, -2.139486, -1.270584, -1.024088, -0.888957, -0.532855, -3.103583, -1.221339, -1.355980, -3.531678, -1.438510, -0.975194,
-1.154231, -3.201762, -0.888957, -0.532855, -3.103583, -1.221339, -0.080261, -1.162697, -1.568557, -1.398519, -1.322129, -0.470660, -0.370953, -0.907343,
-1.355980, -3.531678, -1.438510, -0.975194, -0.080261, -1.162697,
-1.568557, -1.398519, -1.322129, -0.470660, -0.370953, -0.907343,
-1.179017, -3.312239, -1.286363, -1.586076, -0.345100, -0.824173}; -1.179017, -3.312239, -1.286363, -1.586076, -0.345100, -0.824173};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
...@@ -1072,14 +1068,12 @@ TEST_CASE(logsoftmax_test_axis_2) ...@@ -1072,14 +1068,12 @@ TEST_CASE(logsoftmax_test_axis_2)
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = { std::vector<float> s = {
-0.495957, -1.031212, -0.245531, -2.013726, -1.339125, -2.465619, -0.495957, -1.031212, -0.245531, -2.013726, -1.339125, -2.465619, -1.356652, -0.964037,
-1.356652, -0.964037, -2.019250, -0.214522, -0.289569, -0.234392, -2.019250, -0.214522, -0.289569, -0.234392, -2.086591, -2.684439, -2.851651, -2.674176,
-2.086591, -2.684439, -2.851651, -2.674176, -1.697424, -1.889155, -1.697424, -1.889155, -0.401029, -3.064586, -1.173030, -1.306912, -2.177020, -0.834262,
-0.401029, -3.064586, -1.173030, -1.306912, -2.177020, -0.834262, -2.818177, -0.174415, -1.361105, -1.024571, -0.106766, -1.167645, -1.072650, -2.576522,
-2.818177, -0.174415, -1.361105, -1.024571, -0.106766, -1.167645, -0.569261, -1.207483, -3.679894, -2.095913, -0.504264, -3.039291, -1.290559, -1.156812,
-1.072650, -2.576522, -0.569261, -1.207483, -3.679894, -2.095913, -0.126453, -0.551493, -2.506384, -2.646261, -1.905195, -0.206994, -0.191369, -0.959754,
-0.504264, -3.039291, -1.290559, -1.156812, -0.126453, -0.551493,
-2.506384, -2.646261, -1.905195, -0.206994, -0.191369, -0.959754,
-1.948685, -3.671233, -0.875521, -3.111952, -1.905644, -1.6076011}; -1.948685, -3.671233, -0.875521, -3.111952, -1.905644, -1.6076011};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
...@@ -1107,14 +1101,12 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -1107,14 +1101,12 @@ TEST_CASE(logsoftmax_test_axis_3)
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618}; -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = { std::vector<float> s = {
-0.336904, -3.475825, -1.366154, -0.279366, -2.208430, -2.010934, -0.336904, -3.475825, -1.366154, -0.279366, -2.208430, -2.010934, -0.225511, -2.436562,
-0.225511, -2.436562, -2.167785, -1.572415, -1.784104, -0.470789, -2.167785, -1.572415, -1.784104, -0.470789, -1.067459, -1.801948, -0.711023, -2.307197,
-1.067459, -1.801948, -0.711023, -2.307197, -1.467087, -0.400681, -1.467087, -0.400681, -0.426983, -3.740518, -1.127681, -1.078919, -2.599005, -0.534965,
-0.426983, -3.740518, -1.127681, -1.078919, -2.599005, -0.534965, -2.561400, -0.567617, -1.033025, -2.097713, -0.520463, -1.262245, -1.763230, -2.607658,
-2.561400, -0.567617, -1.033025, -2.097713, -0.520463, -1.262245, -0.281299, -0.814243, -2.627210, -0.724131, -0.655704, -2.123055, -1.018163, -2.480634,
-1.763230, -2.607658, -0.281299, -0.814243, -2.627210, -0.724131, -0.382599, -1.451479, -1.843102, -0.915303, -0.818078, -1.316929, -0.508875, -2.033541,
-0.655704, -2.123055, -1.018163, -2.480634, -0.382599, -1.451479,
-1.843102, -0.915303, -0.818078, -1.316929, -0.508875, -2.033541,
-1.487672, -2.417791, -0.378360, -2.568531, -0.569794, -1.028032}; -1.487672, -2.417791, -0.378360, -2.568531, -0.569794, -1.028032};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment