"vscode:/vscode.git/clone" did not exist on "51884fc21412b1800bb85b28c0f5a0b651d23cef"
Commit 7c453f85 authored by Khalique

added leaky_relu.onnx file and a test for parsing ONNX; added LeakyRelu parsing to onnx_parser

parent ade3a03c
...@@ -56,6 +56,7 @@ struct onnx_parser
        add_generic_op("Sub", op::sub{});
        add_generic_op("Sum", op::add{});
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
...@@ -260,6 +261,19 @@ struct onnx_parser
        return prog.add_instruction(op, std::move(args));
    }
    // Map ONNX LeakyRelu to op::leaky_relu, honoring the optional
    // alpha attribute (the ONNX default is 0.01).
    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01f;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
...
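For reference (not part of this commit): LeakyRelu computes f(x) = x for x >= 0 and alpha * x otherwise. A minimal standalone C++ sketch of the element-wise semantics the parser maps onto op::leaky_relu:

#include <algorithm>
#include <cstdio>
#include <vector>

// Element-wise leaky ReLU with the ONNX default alpha of 0.01.
std::vector<float> leaky_relu_ref(const std::vector<float>& x, float alpha = 0.01f)
{
    std::vector<float> y(x.size());
    std::transform(x.begin(), x.end(), y.begin(),
                   [=](float v) { return v >= 0.0f ? v : alpha * v; });
    return y;
}

int main()
{
    for(float v : leaky_relu_ref({-2.0f, 0.0f, 3.0f}))
        std::printf("%g ", v); // prints: -0.02 0 3
}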
...@@ -14,8 +14,8 @@ shape miopen_leaky_relu::compute_shape(const std::vector<shape>& inputs) const
}
argument miopen_leaky_relu::compute(context& ctx,
                                    const shape& output_shape,
                                    const std::vector<argument>& args) const
{
    // Blend factors for the MIOpen activation call; the leaky slope itself
    // is stored in the activation descriptor.
    float alpha = 1, beta = 0;
    auto x_desc = make_tensor(args[0].get_shape());
...
...@@ -129,12 +129,12 @@ struct miopen_apply
        }
        return ins;
    }
    instruction_ref apply_leaky_relu(instruction_ref ins)
    {
        auto&& op = any_cast<op::leaky_relu>(ins->get_operator());
        auto ad = make_leaky_relu(op.alpha);
        auto output = insert_allocation(ins, ins->get_shape());
        return prog->replace_instruction(
            ins, miopen_leaky_relu{std::move(ad)}, ins->inputs().at(0), output);
...
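For context (not from this commit): make_leaky_relu(op.alpha) presumably configures a MIOpen activation descriptor in leaky-ReLU mode. A rough sketch under that assumption, using MIOpen's activation descriptor API; the actual migraph helper and its ownership handling may differ:

#include <miopen/miopen.h>

// Sketch only: create an activation descriptor set to leaky ReLU with the
// parsed alpha. The trailing arguments are activAlpha, activBeta, activGamma.
miopenActivationDescriptor_t make_leaky_relu_sketch(double alpha)
{
    miopenActivationDescriptor_t desc = nullptr;
    miopenCreateActivationDescriptor(&desc);
    miopenSetActivationDescriptor(desc, miopenActivationLEAKYRELU, alpha, 0.0, 0.0);
    return desc; // caller releases with miopenDestroyActivationDescriptor(desc)
}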
...@@ -368,6 +368,17 @@ struct test_add_relu
    }
};
struct test_leaky_relu
{
    migraph::program create_program() const
    {
        migraph::program p;
        auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
        p.add_instruction(migraph::op::leaky_relu{0.01}, x);
        return p;
    }
};
struct test_conv_pooling
{
    migraph::program create_program() const
...@@ -619,6 +630,7 @@ int main()
    verify_program<test_conv2>();
    verify_program<test_conv_relu>();
    verify_program<test_add_relu>();
    verify_program<test_leaky_relu>();
    verify_program<test_conv_pooling>();
    verify_program<test_gemm>();
    // verify_program<test_gemm_ld>();
...
leaky_relu.onnx (new binary file): serialized ONNX protobuf for a graph named "test-model" containing a single LeakyRelu node ("leaky_relu-example") with an alpha attribute, input "0" and output "1" (a three-element float tensor, matching leaky_relu_test below).
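Such a fixture could plausibly be produced with the ONNX protobuf C++ bindings. A sketch (not from this commit), assuming the generated onnx.pb.h header is available and that the test only needs a single LeakyRelu node over a rank-1 float input; field names follow the ONNX schema, the opset version is chosen for illustration:

#include <fstream>
#include <string>
#include <onnx/onnx_pb.h> // assumption: path to the generated ONNX protobuf header

int main()
{
    onnx::ModelProto model;
    model.add_opset_import()->set_version(9); // illustrative opset version
    auto* graph = model.mutable_graph();
    graph->set_name("test-model");

    // Single LeakyRelu node: input "0" -> output "1", alpha = 0.01
    auto* node = graph->add_node();
    node->set_name("leaky_relu-example");
    node->set_op_type("LeakyRelu");
    node->add_input("0");
    node->add_output("1");
    auto* attr = node->add_attribute();
    attr->set_name("alpha");
    attr->set_type(onnx::AttributeProto::FLOAT);
    attr->set_f(0.01f);

    // Declare graph input "0" as a rank-1 float tensor of size 3
    auto* in = graph->add_input();
    in->set_name("0");
    auto* in_t = in->mutable_type()->mutable_tensor_type();
    in_t->set_elem_type(onnx::TensorProto::FLOAT);
    in_t->mutable_shape()->add_dim()->set_dim_value(3);

    // Declare graph output "1" with the same type and shape
    auto* out = graph->add_output();
    out->set_name("1");
    auto* out_t = out->mutable_type()->mutable_tensor_type();
    out_t->set_elem_type(onnx::TensorProto::FLOAT);
    out_t->mutable_shape()->add_dim()->set_dim_value(3);

    std::string bytes;
    model.SerializeToString(&bytes);
    std::ofstream os("leaky_relu.onnx", std::ios::binary);
    os.write(bytes.data(), static_cast<std::streamsize>(bytes.size()));
}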
...@@ -88,10 +88,23 @@ void pytorch_conv_relu_maxpool_x2()
    EXPECT(p == prog);
}
void leaky_relu_test()
{
    migraph::program p;
    float alpha = 0.01f;
    auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}});
    p.add_instruction(migraph::op::leaky_relu{alpha}, l0);
    auto prog = migraph::parse_onnx("leaky_relu.onnx");
    EXPECT(p == prog);
}
int main()
{
    pytorch_conv_bias_test();
    pytorch_conv_relu_maxpool();
    pytorch_conv_bn_relu_maxpool();
    pytorch_conv_relu_maxpool_x2();
    leaky_relu_test();
}