Commit a9e43ae9 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Change softmax to be compatible with PyTorch softmax.

parent ed9a29bc
...@@ -569,37 +569,22 @@ struct cpu_softmax ...@@ -569,37 +569,22 @@ struct cpu_softmax
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <typename T> template <typename T>
std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
{ {
if(axis == 0) idx.erase(idx.begin() + axis);
{ return batch_shape.index(idx);
return 0;
}
else
{
std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
return batch_shape.index(batch_idx.begin(), batch_idx.end());
}
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto lens = output_shape.lens(); auto batch_lens = output_shape.lens();
std::vector<std::size_t> batch_lens{}; batch_lens.erase(batch_lens.begin() + op.axis);
if(op.axis == 0) shape batch_shape{shape::int32_type, batch_lens};
{
batch_lens.push_back(1);
}
else
{
batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
}
shape batch_shape{migraphx::shape::uint32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
std::vector<value_type> batch_max(batch_shape.elements(), std::vector<value_type> batch_max(batch_shape.elements(), std::numeric_limits<value_type>::lowest());
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end())); batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
...@@ -607,14 +592,12 @@ struct cpu_softmax ...@@ -607,14 +592,12 @@ struct cpu_softmax
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index]; output(idx.begin(), idx.end()) = std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
}); });
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = this->compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
auto output_val = std::exp(output(idx.begin(), idx.end()));
output(idx.begin(), idx.end()) = output_val;
batch_sum[index] += output(idx.begin(), idx.end()); batch_sum[index] += output(idx.begin(), idx.end());
}); });
......
...@@ -936,11 +936,14 @@ TEST_CASE(softmax_simple_test) ...@@ -936,11 +936,14 @@ TEST_CASE(softmax_simple_test)
std::vector<float> s = {0.377541, 0.622459}; std::vector<float> s = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
p.add_instruction(migraphx::op::softmax{}, al); p.add_instruction(migraphx::op::softmax{1}, al);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(2); std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
for (auto v : results_vector)
std::cout << v << "\t";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
......
...@@ -569,13 +569,13 @@ struct test_sub2 : verify_program<test_sub2> ...@@ -569,13 +569,13 @@ struct test_sub2 : verify_program<test_sub2>
} }
}; };
struct test_softmax : verify_program<test_softmax> struct test_softmax1 : verify_program<test_softmax1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 1, 1}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
p.add_instruction(migraphx::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{0}, x);
return p; return p;
} }
}; };
...@@ -592,6 +592,25 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -592,6 +592,25 @@ struct test_softmax2 : verify_program<test_softmax2>
} }
}; };
// Verification program for softmax along a compile-time axis.
// Instantiated below for every axis of a rank-4 tensor so the CPU and
// reference/GPU results are compared for each possible reduction axis.
template <int Axis>
struct test_softmax : verify_program<test_softmax<Axis>>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        // Fixed rank-4 input shape; Axis selects which dimension softmax
        // normalizes over.
        const migraphx::shape input_shape{migraphx::shape::float_type, {3, 4, 5, 6}};
        auto input = p.add_parameter("0", input_shape);
        p.add_instruction(migraphx::op::softmax{Axis}, input);
        return p;
    }
};
// Explicit instantiations: one verification test per axis.
template struct test_softmax<0>;
template struct test_softmax<1>;
template struct test_softmax<2>;
template struct test_softmax<3>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment