"vscode:/vscode.git/clone" did not exist on "e5d21a2434022994a2d63b7354d74367c07df6b8"
Commit 87aed719 authored by Paul's avatar Paul
Browse files

Add miopen softmax

parent ee346df0
...@@ -28,10 +28,6 @@ struct unknown ...@@ -28,10 +28,6 @@ struct unknown
else else
return input.front(); return input.front();
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const unknown& x) friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{ {
os << x.name(); os << x.name();
...@@ -57,7 +53,9 @@ struct onnx_parser ...@@ -57,7 +53,9 @@ struct onnx_parser
add_generic_op("MatMul", gemm{}); add_generic_op("MatMul", gemm{});
add_generic_op("Mul", mul{}); add_generic_op("Mul", mul{});
add_generic_op("Relu", activation{"relu"}); add_generic_op("Relu", activation{"relu"});
add_generic_op("Softmax", softmax{});
add_generic_op("Sub", sub{}); add_generic_op("Sub", sub{});
add_generic_op("Sum", add{});
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
......
...@@ -305,6 +305,34 @@ struct miopen_relu ...@@ -305,6 +305,34 @@ struct miopen_relu
} }
}; };
struct miopen_softmax
{
softmax op;
std::string name() const { return "gpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return inputs.at(1);
}
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenSoftmaxForward(ctx.handle.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
return args[1];
}
};
struct miopen_apply struct miopen_apply
{ {
program* prog = nullptr; program* prog = nullptr;
...@@ -350,6 +378,10 @@ struct miopen_apply ...@@ -350,6 +378,10 @@ struct miopen_apply
{ {
check_shape(s, apply_batch_norm_inference(it)); check_shape(s, apply_batch_norm_inference(it));
} }
else if(it->name() == "softmax")
{
check_shape(s, apply_softmax(it));
}
} }
} }
...@@ -404,6 +436,13 @@ struct miopen_apply ...@@ -404,6 +436,13 @@ struct miopen_apply
return ins; return ins;
} }
instruction_ref apply_softmax(instruction_ref ins)
{
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_softmax{}, ins->inputs().at(0), output);
}
instruction_ref apply_add(instruction_ref ins) instruction_ref apply_add(instruction_ref ins)
{ {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
......
...@@ -235,6 +235,17 @@ struct test_add_broadcast5 ...@@ -235,6 +235,17 @@ struct test_add_broadcast5
} }
}; };
struct test_softmax
{
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}});
p.add_instruction(migraph::softmax{}, x);
return p;
}
};
struct test_conv struct test_conv
{ {
migraph::program create_program() const migraph::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment