Commit 87aed719 authored by Paul

Add miopen softmax

parent ee346df0
@@ -28,10 +28,6 @@ struct unknown
        else
            return input.front();
    }
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        MIGRAPH_THROW("not computable");
    }
    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
    {
        os << x.name();
@@ -57,7 +53,9 @@ struct onnx_parser
        add_generic_op("MatMul", gemm{});
        add_generic_op("Mul", mul{});
        add_generic_op("Relu", activation{"relu"});
        add_generic_op("Softmax", softmax{});
        add_generic_op("Sub", sub{});
        add_generic_op("Sum", add{});
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
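The hunk above registers the new Softmax ONNX node type next to the existing generic operators. The onnx_parser internals are not part of this diff; purely as an illustrative sketch of the dispatch-table pattern such registration usually relies on (all names below are hypothetical, not migraph's actual API):

#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical sketch only: a registry keyed by the ONNX op type. Each entry
// is a callback that appends the matching operator to the program being built;
// unrecognized node types fall through to a caller-side default (for example
// the `unknown` operator touched in the first hunk).
struct op_registry
{
    std::unordered_map<std::string, std::function<void()>> table;

    void add(const std::string& onnx_name, std::function<void()> build)
    {
        table[onnx_name] = std::move(build);
    }

    bool parse(const std::string& onnx_name) const
    {
        auto it = table.find(onnx_name);
        if(it == table.end())
            return false;
        it->second();
        return true;
    }
};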
@@ -305,6 +305,34 @@ struct miopen_relu
    }
};

struct miopen_softmax
{
    softmax op;
    std::string name() const { return "gpu::softmax"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Two inputs: the data tensor and the output allocation; the result
        // takes the allocation's shape.
        check_shapes{inputs, *this}.has(2).standard();
        return inputs.at(1);
    }
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
    {
        float alpha = 1, beta = 0;
        auto x_desc = make_tensor(args[0].get_shape());
        auto y_desc = make_tensor(output_shape);
        // args[0] is the input tensor, args[1] the preallocated output buffer;
        // MIOpen writes the result into args[1], which is then returned.
        miopenSoftmaxForward(ctx.handle.get(),
                             &alpha,
                             x_desc.get(),
                             args[0].implicit(),
                             &beta,
                             y_desc.get(),
                             args[1].implicit());
        return args[1];
    }
};

struct miopen_apply
{
    program* prog = nullptr;
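The miopen_softmax wrapper above forwards directly to MIOpen's softmax primitive. As a rough standalone sketch of that underlying call, outside MIGraphX and with error checking omitted, assuming MIOpen's original (non-V2) miopenSoftmaxForward signature and the HIP runtime for device buffers:

#include <hip/hip_runtime.h>
#include <miopen/miopen.h>

int main()
{
    const int n = 5, c = 3, h = 4, w = 2; // matches the test_softmax shape
    const size_t bytes = size_t(n) * c * h * w * sizeof(float);

    miopenHandle_t handle;
    miopenCreate(&handle);

    // Same 4-D descriptor layout for input and output, as in compute() above.
    miopenTensorDescriptor_t x_desc, y_desc;
    miopenCreateTensorDescriptor(&x_desc);
    miopenCreateTensorDescriptor(&y_desc);
    miopenSet4dTensorDescriptor(x_desc, miopenFloat, n, c, h, w);
    miopenSet4dTensorDescriptor(y_desc, miopenFloat, n, c, h, w);

    void *x, *y;
    hipMalloc(&x, bytes);
    hipMalloc(&y, bytes);
    // In real use the input data would be copied into x (e.g. with hipMemcpy)
    // before the call.

    float alpha = 1.0f, beta = 0.0f;
    miopenSoftmaxForward(handle, &alpha, x_desc, x, &beta, y_desc, y);

    hipFree(x);
    hipFree(y);
    miopenDestroyTensorDescriptor(x_desc);
    miopenDestroyTensorDescriptor(y_desc);
    miopenDestroy(handle);
}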
@@ -350,6 +378,10 @@ struct miopen_apply
            {
                check_shape(s, apply_batch_norm_inference(it));
            }
            else if(it->name() == "softmax")
            {
                check_shape(s, apply_softmax(it));
            }
        }
    }
@@ -404,6 +436,13 @@ struct miopen_apply
        return ins;
    }

    instruction_ref apply_softmax(instruction_ref ins)
    {
        // Allocate the GPU output buffer and hand it to the op as its second argument.
        auto output = insert_allocation(ins, ins->get_shape());
        return prog->replace_instruction(
            ins, miopen_softmax{}, ins->inputs().at(0), output);
    }

    instruction_ref apply_add(instruction_ref ins)
    {
        auto output = insert_allocation(ins, ins->get_shape());
@@ -235,6 +235,17 @@ struct test_add_broadcast5
    }
};

struct test_softmax
{
    migraph::program create_program() const
    {
        migraph::program p;
        auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {5, 3, 4, 2}});
        p.add_instruction(migraph::softmax{}, x);
        return p;
    }
};

struct test_conv
{
    migraph::program create_program() const
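test_softmax above builds a {5, 3, 4, 2} float tensor and applies migraph::softmax, which the GPU target now lowers to miopen_softmax. For checking results by hand, here is a hedged reference implementation of channel-wise 4-D softmax, softmax(x)_c = exp(x_c) / sum_k exp(x_k) computed per (n, h, w) position; it is plain C++, independent of migraph, and assumes the channel-wise convention documented for MIOpen's softmax forward:

#include <algorithm>
#include <cmath>
#include <vector>

// Hedged reference softmax for an NCHW float tensor, normalizing over the
// channel axis. Intended only for hand-checking test_softmax, not as the
// library's implementation.
std::vector<float> softmax_nchw(const std::vector<float>& x, int n, int c, int h, int w)
{
    std::vector<float> y(x.size());
    for(int in = 0; in < n; in++)
        for(int ih = 0; ih < h; ih++)
            for(int iw = 0; iw < w; iw++)
            {
                auto at = [&](int ic) { return ((in * c + ic) * h + ih) * w + iw; };
                // Subtract the channel max for numerical stability, then normalize.
                float m = x[at(0)];
                for(int ic = 1; ic < c; ic++)
                    m = std::max(m, x[at(ic)]);
                float sum = 0;
                for(int ic = 0; ic < c; ic++)
                    sum += std::exp(x[at(ic)] - m);
                for(int ic = 0; ic < c; ic++)
                    y[at(ic)] = std::exp(x[at(ic)] - m) / sum;
            }
    return y;
}

With the {5, 3, 4, 2} shape from the test, c = 3, so each triple of outputs along the channel axis should sum to 1.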