Commit 5775b579 authored by Paul's avatar Paul
Browse files

Fix softmax operator

parent 311ae1d1
......@@ -413,11 +413,6 @@ struct atan : unary
std::string name() const { return "atan"; }
};
struct softmax : unary
{
std::string name() const { return "softmax"; }
};
struct tanh : unary
{
std::string name() const { return "tanh"; }
......@@ -433,6 +428,16 @@ struct neg : unary
std::string name() const { return "neg"; }
};
struct softmax
{
std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1).only_dims(4);
return inputs.at(0);
}
};
struct flatten
{
uint64_t axis = 0;
......
......@@ -53,7 +53,6 @@ struct onnx_parser
add_generic_op("MatMul", gemm{});
add_generic_op("Mul", mul{});
add_generic_op("Relu", activation{"relu"});
add_generic_op("Softmax", softmax{});
add_generic_op("Sub", sub{});
add_generic_op("Sum", add{});
......@@ -65,6 +64,7 @@ struct onnx_parser
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
}
template <class F>
......@@ -101,6 +101,15 @@ struct onnx_parser
});
}
instruction_ref
parse_softmax(const std::string&, attribute_map, std::vector<instruction_ref> args)
{
auto dims = args.front()->get_shape().lens();
auto r = prog.add_instruction(reshape{{long(dims[0]), 1, 1, long(dims[1])}}, args.front());
auto s = prog.add_instruction(softmax{}, r);
return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......
......@@ -12,6 +12,7 @@
namespace migraph {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_COMPILE)
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_TRACE_EVAL)
struct program_impl
{
......@@ -317,8 +318,18 @@ argument generic_eval(const program& p,
argument program::eval(std::unordered_map<std::string, argument> params) const
{
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
if(enabled(MIGRAPH_TRACE_EVAL{})) {
auto& ctx = this->impl->ctx;
return generic_eval(*this, this->impl->ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: " << ins->name() << std::endl;
return f();
});
} else {
return generic_eval(*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
}
}
double common_average(const std::vector<double>& v)
......
......@@ -312,7 +312,7 @@ struct miopen_softmax
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return inputs.at(1);
return op.compute_shape({inputs.at(0)});
}
argument
......@@ -438,8 +438,9 @@ struct miopen_apply
instruction_ref apply_softmax(instruction_ref ins)
{
auto&& op = any_cast<softmax>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_softmax{}, ins->inputs().at(0), output);
return prog->replace_instruction(ins, miopen_softmax{op}, ins->inputs().at(0), output);
}
instruction_ref apply_add(instruction_ref ins)
......
......@@ -246,6 +246,17 @@ struct test_softmax
}
};
struct test_softmax2
{
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000}});
p.add_instruction(migraph::softmax{}, x);
return p;
}
};
struct test_conv
{
migraph::program create_program() const
......@@ -527,6 +538,9 @@ int main()
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_softmax>();
// TODO: Add reshapes to make this a valid test case
// verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment