Unverified Commit eacad500 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Parse clip issue (#625)

* fix a bug in parsing clip

* clang format

* add unit tests

* clang format
parent 42a97dfb
...@@ -442,14 +442,13 @@ struct onnx_parser ...@@ -442,14 +442,13 @@ struct onnx_parser
bool min_used = false; bool min_used = false;
bool max_used = false; bool max_used = false;
if(args.size() == 3) if(args.size() == 3 and args[2]->name() != "undefined")
{ {
min_arg = args[1];
max_arg = args[2]; max_arg = args[2];
min_used = true;
max_used = true; max_used = true;
} }
else if(args.size() == 2)
if(args.size() >= 2 and args[1]->name() != "undefined")
{ {
min_arg = args[1]; min_arg = args[1];
min_used = true; min_used = true;
...@@ -467,17 +466,31 @@ struct onnx_parser ...@@ -467,17 +466,31 @@ struct onnx_parser
} }
if(min_used) if(min_used)
{
min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg); min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);
}
if(max_used) if(max_used)
{
max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg); max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);
}
if(min_used and max_used) if(min_used and max_used)
{
return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg); return prog.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
if(min_used) }
else if(max_used)
{
return prog.add_instruction(make_op("min"), args[0], max_arg);
}
else if(min_used)
{
return prog.add_instruction(make_op("max"), args[0], min_arg); return prog.add_instruction(make_op("max"), args[0], min_arg);
}
return prog.add_instruction(make_op("identity"), args[0]); else
{
return prog.add_instruction(make_op("identity"), args[0]);
}
} }
instruction_ref parse_arg_op(const std::string&, instruction_ref parse_arg_op(const std::string&,
...@@ -2535,8 +2548,11 @@ struct onnx_parser ...@@ -2535,8 +2548,11 @@ struct onnx_parser
void parse_undefined(const std::string& name) void parse_undefined(const std::string& name)
{ {
auto ins = prog.add_instruction(op::undefined{}); if(!contains(instructions, name))
instructions[name] = ins; {
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
} }
static attribute_map get_attributes(const onnx::NodeProto& node) static attribute_map get_attributes(const onnx::NodeProto& node)
......
...@@ -378,6 +378,20 @@ def clip_test_op11(): ...@@ -378,6 +378,20 @@ def clip_test_op11():
return ([node], [x], [y], [min_val, max_val]) return ([node], [x], [y], [min_val, max_val])
@onnx_test
def clip_test_op11_max_only():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [0.0])
node = onnx.helper.make_node('Clip',
inputs=['0', '', 'max'],
outputs=['1'])
return ([node], [x], [y], [max_val])
@onnx_test @onnx_test
def clip_test_op11_min_only(): def clip_test_op11_min_only():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...@@ -400,6 +414,16 @@ def clip_test_op11_no_args(): ...@@ -400,6 +414,16 @@ def clip_test_op11_no_args():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def clip_test_op11_no_args1():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3])
node = onnx.helper.make_node('Clip', inputs=['0', '', ''], outputs=['1'])
return ([node], [x], [y])
@onnx_test @onnx_test
def concat_test(): def concat_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3])
......
...@@ -301,6 +301,21 @@ TEST_CASE(clip_test) ...@@ -301,6 +301,21 @@ TEST_CASE(clip_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(clip_test_op11_max_only)
{
migraphx::program p;
auto max_val = p.add_literal(0.0f);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::undefined{});
max_val = p.add_instruction(migraphx::op::multibroadcast{{3}}, max_val);
auto r = p.add_instruction(migraphx::op::min{}, l0, max_val);
p.add_return({r});
auto prog = migraphx::parse_onnx("clip_test_op11_max_only.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test_op11) TEST_CASE(clip_test_op11)
{ {
migraphx::program p; migraphx::program p;
...@@ -337,6 +352,19 @@ TEST_CASE(clip_test_op11_no_args) ...@@ -337,6 +352,19 @@ TEST_CASE(clip_test_op11_no_args)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(clip_test_op11_no_args1)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
p.add_instruction(migraphx::op::undefined{});
auto r = p.add_instruction(migraphx::op::identity{}, l0);
p.add_return({r});
auto prog = migraphx::parse_onnx("clip_test_op11_no_args1.onnx");
EXPECT(p == prog);
}
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -204,10 +204,6 @@ def create_backend_test(testname=None, target_device=None): ...@@ -204,10 +204,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_softmax_default_axis_cpu') backend_test.exclude(r'test_softmax_default_axis_cpu')
# error cases # error cases
backend_test.exclude(r'test_clip_default_inbounds_cpu')
backend_test.exclude(r'test_clip_default_int8_inbounds_cpu')
backend_test.exclude(r'test_clip_default_int8_max_cpu')
backend_test.exclude(r'test_clip_default_max_cpu')
backend_test.exclude(r'test_constant_pad_cpu') backend_test.exclude(r'test_constant_pad_cpu')
backend_test.exclude(r'test_constantofshape_float_ones_cpu') backend_test.exclude(r'test_constantofshape_float_ones_cpu')
backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu') backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment