Unverified Commit 3bbd808e authored by Shucai Xiao, committed by GitHub

Improve parsing of the pad operator (#466)



* change parse operator function signature
* clang format
* add parsing the split operator
* clang format
* add parsing split operator
* make squeeze/unsqueeze inputs to standard shape
* change parsing pad to support opset 11 definition
* clang format
* add unit tests for the split operator
* clang format
* fix cppcheck error
* clang format
* update tests for multiple program outputs
* clang format
* fix cppcheck error
* clang format
* refine an error message
* add unit tests for the pad operator
* clang format
* fix review comments
* fix unit test error
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8a89b7f8
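
The substance of the change: before opset 11, the ONNX Pad operator carried its padding sizes and fill value as node attributes; from opset 11 onward, pads is the second input tensor and the optional constant fill value is the third input, which is why the parser now looks at args before falling back to attributes. A minimal sketch of the two node forms, written with the same onnx.helper API as the test generator below (the names x, pads, and const_val are illustrative only):

    from onnx import helper

    # Opset <= 10: padding sizes and fill value are attributes on the node.
    pad_old = helper.make_node('Pad',
                               inputs=['x'],
                               outputs=['y'],
                               mode='constant',
                               pads=[1, 1, 2, 2],  # [x0_begin, x1_begin, x0_end, x1_end]
                               value=1.0)

    # Opset >= 11: pads (an int64 tensor) and the optional constant fill
    # value become the second and third inputs, typically fed by Constant
    # nodes or initializers, as in pad_3arg_test below.
    pad_new = helper.make_node('Pad',
                               inputs=['x', 'pads', 'const_val'],
                               outputs=['y'],
                               mode='constant')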
@@ -1103,26 +1103,56 @@ struct onnx_parser
     instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
     {
         std::vector<int64_t> pads{};
-        float value = 0.0f;
-        if(contains(info.attributes, "pads"))
+        if(args.size() >= 2)
+        {
+            auto pad_arg = args.at(1)->eval();
+            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
+            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
+        }
+        else if(contains(info.attributes, "pads"))
         {
             auto&& pad_vals = info.attributes["pads"].ints();
             pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
         }
+        else
+        {
+            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
+        }
+
         // check if padding is actually being done (at least one value is nonzero)
         if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
         {
             return prog.add_instruction(migraphx::op::identity{}, args.front());
         }
-        if(contains(info.attributes, "value"))
+
+        float value = 0.0f;
+        // third input is the value
+        if(args.size() == 3)
+        {
+            auto val_ins = args.at(2);
+            if(!val_ins->can_eval())
+            {
+                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
+            }
+            auto val_arg = val_ins->eval();
+            if(val_arg.get_shape().elements() != 1)
+            {
+                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
+            }
+            value = val_arg.at<float>();
+        }
+        else if(contains(info.attributes, "value"))
         {
             value = parse_value(info.attributes.at("value")).at<float>();
         }
+
         if(contains(info.attributes, "mode"))
         {
             auto mode = info.attributes.at("mode").s();
             if(mode != "constant")
-                MIGRAPHX_THROW("migraphx currently only supports constant padding");
+            {
+                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
+            }
         }
         return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
     }
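A note on the pads layout: ONNX flattens per-axis padding as [x1_begin, x2_begin, ..., x1_end, x2_end], and judging from the code above and the new unit test, that flat vector is handed to migraphx::op::pad unchanged, so output dimension i is dims[i] + pads[i] + pads[i + rank]. A quick numpy sanity check of the shapes used in pad_3arg_test (a sketch; the variable names are mine):

    import numpy as np

    x = np.zeros((2, 2), dtype=np.float32)
    pads = [1, 1, 2, 2]  # begin values for all axes first, then end values
    rank = x.ndim
    out_dims = [x.shape[i] + pads[i] + pads[i + rank] for i in range(rank)]
    print(out_dims)  # [5, 5]

    # np.pad expects per-axis (begin, end) pairs rather than the flat layout:
    y = np.pad(x, [(pads[i], pads[i + rank]) for i in range(rank)],
               mode='constant', constant_values=1.0)
    assert y.shape == tuple(out_dims)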
@@ -1456,6 +1456,38 @@ def pad_test():
     return ([node], [x], [y])
 
 
+@onnx_test
+def pad_3arg_test():
+    values = np.array([1])
+    val_tensor = helper.make_tensor(name='val',
+                                    data_type=TensorProto.FLOAT,
+                                    dims=values.reshape(()).shape,
+                                    vals=values.astype(float))
+    arg_val = onnx.helper.make_node('Constant',
+                                    inputs=[],
+                                    outputs=['arg_val'],
+                                    value=val_tensor)
+
+    sizes = np.array([1, 1, 2, 2])
+    pad_tensor = helper.make_tensor(name='pad_size',
+                                    data_type=TensorProto.INT32,
+                                    dims=sizes.shape,
+                                    vals=sizes.astype(int))
+    arg_pad = onnx.helper.make_node('Constant',
+                                    inputs=[],
+                                    outputs=['arg_pad'],
+                                    value=pad_tensor)
+
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [5, 5])
+
+    node = onnx.helper.make_node('Pad',
+                                 inputs=['0', 'arg_pad', 'arg_val'],
+                                 outputs=['1'])
+
+    return ([arg_val, arg_pad, node], [x], [y])
+
+
 @onnx_test
 def pow_test():
     arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
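As a hand check of what the generated model should compute: the pads tensor [1, 1, 2, 2] and the scalar fill value 1.0 above translate to the following numpy call (a sketch, not part of the test suite):

    import numpy as np

    x = np.arange(4, dtype=np.float32).reshape(2, 2)
    # ONNX pads [1, 1, 2, 2] -> ((1, 2), (1, 2)) as (before, after) per axis
    y = np.pad(x, ((1, 2), (1, 2)), mode='constant', constant_values=1.0)
    print(y.shape)  # (5, 5), matching the output value info '1' declared above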
@@ -1123,6 +1123,20 @@ TEST_CASE(pad_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(pad_3arg_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
+    p.add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
+    p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 1, 2, 2}});
+    auto r = p.add_instruction(migraphx::op::pad{{1, 1, 2, 2}, 1.0f}, l0);
+    p.add_return({r});
+
+    auto prog = migraphx::parse_onnx("pad_3arg_test.onnx");
+    EXPECT(p == prog);
+}
+
+
 TEST_CASE(pow_test)
 {
     migraphx::program p;