Commit 5775b579 authored by Paul

Fix softmax operator

parent 311ae1d1
@@ -413,11 +413,6 @@ struct atan : unary
     std::string name() const { return "atan"; }
 };
 
-struct softmax : unary
-{
-    std::string name() const { return "softmax"; }
-};
-
 struct tanh : unary
 {
     std::string name() const { return "tanh"; }
@@ -433,6 +428,16 @@ struct neg : unary
     std::string name() const { return "neg"; }
 };
 
+struct softmax
+{
+    std::string name() const { return "softmax"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs}.has(1).only_dims(4);
+        return inputs.at(0);
+    }
+};
+
 struct flatten
 {
     uint64_t axis = 0;
...
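Note: `softmax` moves out of the `unary` family and gains its own `compute_shape`, which admits exactly one input with exactly four dimensions and passes that shape through unchanged. A minimal sketch of the contract this enforces (standalone assertions, not part of the commit; it assumes `migraph::shape` is equality-comparable):

```cpp
// Sketch only: the shape rule the new softmax operator enforces.
migraph::softmax op{};
migraph::shape s4{migraph::shape::float_type, {5, 3, 4, 2}};
assert(op.compute_shape({s4}) == s4); // a single 4-D input passes through unchanged

migraph::shape s2{migraph::shape::float_type, {1, 1000}};
// op.compute_shape({s2}); // would throw: only_dims(4) rejects anything but 4-D
```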
@@ -53,7 +53,6 @@ struct onnx_parser
         add_generic_op("MatMul", gemm{});
         add_generic_op("Mul", mul{});
         add_generic_op("Relu", activation{"relu"});
-        add_generic_op("Softmax", softmax{});
         add_generic_op("Sub", sub{});
         add_generic_op("Sum", add{});
@@ -65,6 +64,7 @@ struct onnx_parser
         add_mem_op("Flatten", &onnx_parser::parse_flatten);
         add_mem_op("Gemm", &onnx_parser::parse_gemm);
         add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
+        add_mem_op("Softmax", &onnx_parser::parse_softmax);
     }
 
     template <class F>
@@ -101,6 +101,15 @@ struct onnx_parser
         });
     }
 
+    instruction_ref
+    parse_softmax(const std::string&, attribute_map, std::vector<instruction_ref> args)
+    {
+        auto dims = args.front()->get_shape().lens();
+        auto r = prog.add_instruction(reshape{{long(dims[0]), 1, 1, long(dims[1])}}, args.front());
+        auto s = prog.add_instruction(softmax{}, r);
+        return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
+    }
+
     instruction_ref
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
...
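Note: ONNX `Softmax` is parsed here under the assumption of a 2-D {batch, classes} input, while the `softmax` operator above only accepts 4-D shapes, so `parse_softmax` wraps it in a reshape round-trip. A sketch of the equivalent program for a hypothetical {32, 1000} input (the dimensions are illustrative):

```cpp
// Sketch: the program parse_softmax would build for a 2-D input of dims {32, 1000}.
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {32, 1000}});
auto r = p.add_instruction(migraph::reshape{{32, 1, 1, 1000}}, x); // lift to 4-D NCHW
auto s = p.add_instruction(migraph::softmax{}, r);                 // now passes only_dims(4)
p.add_instruction(migraph::reshape{{32, 1000}}, s);                // restore {batch, classes}
```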
@@ -12,6 +12,7 @@
 namespace migraph {
 
 MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
+MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)
 
 struct program_impl
 {
@@ -317,8 +318,18 @@ argument generic_eval(const program& p,
 argument program::eval(std::unordered_map<std::string, argument> params) const
 {
-    return generic_eval(
-        *this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
+    if(enabled(MIGRAPH_TRACE_EVAL{})) {
+        auto& ctx = this->impl->ctx;
+        return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
+            ctx.finish();
+            std::cout << "Run instruction: " << ins->name() << std::endl;
+            return f();
+        });
+    } else {
+        return generic_eval(*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
+    }
 }
 
 double common_average(const std::vector<double>& v)
...
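Note: with `MIGRAPH_TRACE_EVAL` set in the environment, `program::eval` now prints `Run instruction: <name>` before executing each instruction. The `ctx.finish()` call blocks until all previously enqueued work has completed, so on an asynchronous GPU backend the last name printed points at the instruction actually at fault when evaluation fails, not one that was merely enqueued; the extra synchronization also means traced runs are not representative for timing.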
@@ -312,7 +312,7 @@ struct miopen_softmax
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2).standard();
-        return inputs.at(1);
+        return op.compute_shape({inputs.at(0)});
     }
 
     argument
@@ -438,8 +438,9 @@ struct miopen_apply
     instruction_ref apply_softmax(instruction_ref ins)
     {
+        auto&& op = any_cast<softmax>(ins->get_operator());
         auto output = insert_allocation(ins, ins->get_shape());
-        return prog->replace_instruction(ins, miopen_softmax{}, ins->inputs().at(0), output);
+        return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output);
     }
 
     instruction_ref apply_add(instruction_ref ins)
...
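Note: `miopen_softmax` now carries the front-end `softmax` op (the `op` member used in `compute_shape` above) and derives its output shape from the data input via `op.compute_shape` instead of echoing the shape of the allocated output argument. Correspondingly, `apply_softmax` recovers the op from the instruction with `any_cast<softmax>` and forwards it when lowering to `miopen_softmax{op}`.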
@@ -240,7 +240,18 @@ struct test_softmax
     migraph::program create_program() const
     {
         migraph::program p;
         auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}});
+        p.add_instruction(migraph::softmax{}, x);
+        return p;
+    }
+};
+
+struct test_softmax2
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000}});
         p.add_instruction(migraph::softmax{}, x);
         return p;
     }
@@ -527,6 +538,9 @@ int main()
     verify_program<test_add_broadcast3>();
     verify_program<test_add_broadcast4>();
     verify_program<test_add_broadcast5>();
+    verify_program<test_softmax>();
+    // TODO: Add reshapes to make this a valid test case
+    // verify_program<test_softmax2>();
     verify_program<test_conv>();
     verify_program<test_conv_relu>();
     verify_program<test_add_relu>();
...
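Note: `test_softmax` feeds the operator a 4-D {5, 3, 4, 2} parameter, which the new `only_dims(4)` check accepts directly. `test_softmax2` passes a 2-D {1, 1000} parameter straight into `migraph::softmax` with no surrounding reshapes, so the shape check would reject it; per the TODO above, it stays commented out until it gets the same reshape round-trip the ONNX parser applies.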