Commit 65faffa0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent a9e43ae9
...@@ -584,7 +584,8 @@ struct cpu_softmax ...@@ -584,7 +584,8 @@ struct cpu_softmax
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::numeric_limits<value_type>::lowest()); std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end())); batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
...@@ -592,12 +593,13 @@ struct cpu_softmax ...@@ -592,12 +593,13 @@ struct cpu_softmax
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = std::exp(input(idx.begin(), idx.end()) - batch_max[index]); output(idx.begin(), idx.end()) =
std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
}); });
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += output(idx.begin(), idx.end()); batch_sum[index] += output(idx.begin(), idx.end());
}); });
......
...@@ -941,7 +941,7 @@ TEST_CASE(softmax_simple_test) ...@@ -941,7 +941,7 @@ TEST_CASE(softmax_simple_test)
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(2); std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for (auto v : results_vector) for(auto v : results_vector)
std::cout << v << "\t"; std::cout << v << "\t";
std::cout << std::endl; std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment