Commit 7c453f85 authored by Khalique's avatar Khalique
Browse files

added leaky_relu.onnx file, added test for parsing onnx

parent ade3a03c
......@@ -56,6 +56,7 @@ struct onnx_parser
add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{});
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
......@@ -260,6 +261,19 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args));
}
instruction_ref parse_leaky_relu(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
float alpha = 0.01;
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
op::leaky_relu op{alpha};
return prog.add_instruction(op, args.front());
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -14,8 +14,8 @@ shape miopen_leaky_relu::compute_shape(const std::vector<shape>& inputs) const
}
argument miopen_leaky_relu::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
......
......@@ -129,12 +129,12 @@ struct miopen_apply
}
return ins;
}
instruction_ref apply_leaky_relu(instruction_ref ins)
{
auto&& op = any_cast<op::leaky_relu>(ins->get_operator());
auto ad = make_leaky_relu(op.alpha);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_leaky_relu{std::move(ad)}, ins->inputs().at(0), output);
......
......@@ -368,6 +368,17 @@ struct test_add_relu
}
};
struct test_leaky_relu
{
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_instruction(migraph::op::leaky_relu{0.01}, x);
return p;
}
};
struct test_conv_pooling
{
migraph::program create_program() const
......@@ -619,6 +630,7 @@ int main()
verify_program<test_conv2>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
verify_program<test_gemm>();
// verify_program<test_gemm_ld>();
......
leaky_relu-example:R
"
01" LeakyRelu*
alpha
#<
test-modelZ
0

b
1

B
\ No newline at end of file
......@@ -88,10 +88,23 @@ void pytorch_conv_relu_maxpool_x2()
EXPECT(p == prog);
}
void leaky_relu_test()
{
migraph::program p;
float alpha = 0.01f;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {3}});
p.add_instruction(migraph::op::leaky_relu{alpha}, l0);
auto prog = migraph::parse_onnx("leaky_relu.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment