Unverified commit 94bda243 authored by Attila Dusnoki, committed by GitHub

Add axes (optional) input to Pad (#2178)

parent 52c74f0e
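Newer ONNX opsets add an optional `axes` input to Pad: `pads` then lists begin/end values only for the listed axes, and every other dimension keeps zero padding. A minimal Python sketch of the expansion the parser performs (`expand_pads` is illustrative, not a MIGraphX API):

# A minimal sketch (illustrative, not part of this commit) of the pads
# expansion: `pads` holds begin values for each listed axis followed by
# end values, i.e. [axes[0]_begin, ..., axes[n-1]_begin,
#                   axes[0]_end,   ..., axes[n-1]_end].
def expand_pads(pads, axes, input_rank):
    assert len(pads) == 2 * len(axes)
    full = [0] * (2 * input_rank)
    for i, axis in enumerate(axes):
        axis = axis % input_rank  # axes may be negative
        full[axis] = pads[i]  # begin value
        full[axis + input_rank] = pads[i + len(axes)]  # end value
    return full

# For a rank-4 input, axes [1, 3] with pads [1, 3, 2, 4] expand to
# [0, 1, 0, 3, 0, 2, 0, 4], the per-dimension form used by the pad op.
print(expand_pads([1, 3, 2, 4], [1, 3], 4))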
......@@ -115,65 +115,71 @@ struct parse_pad : op_parser<parse_pad>
{
    std::vector<op_desc> operators() const { return {{"Pad"}}; }

    std::string parse_mode(const onnx_parser::node_info& info,
                           const std::vector<instruction_ref>& args) const
    {
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode == "reflect")
            {
                if(args.front()->get_shape().dynamic())
                {
                    MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
                }
            }
            else if(mode != "constant")
            {
                MIGRAPHX_THROW(
                    "PARSE_PAD: migraphx currently only supports constant and reflect padding");
            }
            return mode;
        }
        // default mode
        return "constant";
    }

    std::vector<int64_t> parse_pads(const onnx_parser::node_info& info,
                                    const std::vector<instruction_ref>& args) const
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: `pads` input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes.at("pads").ints();
            pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: `pads` must be available");
        }
        return pads;
    }

    float parse_constant_value(const onnx_parser& parser,
                               const onnx_parser::node_info& info,
                               const std::vector<instruction_ref>& args) const
    {
        float value = 0.0f;
        // third input is the value
        if(args.size() >= 3 and args.at(2)->get_shape().scalar())
        {
            auto val_ins = args.at(2);
            if(not val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input `value` must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: `value` should contain only one element");
            }
            value = val_arg.at<float>();
        }
......@@ -181,6 +187,81 @@ struct parse_pad : op_parser<parse_pad>
        else if(contains(info.attributes, "value"))
        {
            value = parser.parse_value(info.attributes.at("value")).at<float>();
        }
        return value;
    }
    std::vector<int64_t> parse_axes(const std::vector<instruction_ref>& args,
                                    bool is_constant_mode) const
    {
        std::vector<int64_t> axes{};
        // in constant mode `axes` is the 4th input (after pads and value),
        // otherwise it is the 3rd
        const size_t pos = is_constant_mode ? 4 : 3;
        if(args.size() >= pos)
        {
            auto axes_arg = args.at(pos - 1)->eval();
            check_arg_empty(axes_arg, "PARSE_PAD: variable `axes` input not supported");
            axes_arg.visit([&](auto v) { axes.assign(v.begin(), v.end()); });
        }
        return axes;
    }
    std::vector<int64_t> calculate_pads_with_axes(const std::vector<int64_t>& pads,
                                                  const std::vector<int64_t>& axes,
                                                  size_t input_rank) const
    {
        size_t num_axes = axes.size();
        if(num_axes * 2 != pads.size())
        {
            MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
                           "number of elements of axes");
        }
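        // Axes not listed keep zero padding: e.g. for input_rank = 4,
        // axes = {1, 3} and pads = {1, 3, 2, 4} yield
        // new_pads = {0, 1, 0, 3, 0, 2, 0, 4}.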
        std::vector<int64_t> new_pads(input_rank * 2);
        for(size_t idx{0}; idx < num_axes; ++idx)
        {
            // axis can be negative
            int64_t axis = axes[idx] < 0 ? input_rank + axes[idx] : axes[idx];
            // pads format is x1_begin, x2_begin, ..., x1_end, x2_end, ...
            new_pads[axis] = pads[idx];
            new_pads[axis + input_rank] = pads[idx + num_axes];
        }
        return new_pads;
    }
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        std::vector<int64_t> pads = parse_pads(info, args);
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int64_t& i) { return i == 0; }))
        {
            return info.add_instruction(make_op("identity"), args.front());
        }

        std::string mode      = parse_mode(info, args);
        bool is_constant_mode = mode == "constant";
        float value = is_constant_mode ? parse_constant_value(parser, info, args) : 0.0f;

        std::vector<int64_t> axes = parse_axes(args, is_constant_mode);
        size_t input_rank         = args.front()->get_shape().ndim();
        if(not axes.empty())
        {
            pads = calculate_pads_with_axes(pads, axes, input_rank);
        }
        if(pads.size() != input_rank * 2)
        {
            MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
                           "input rank");
        }

        if(mode == "reflect")
        {
            return reflect_pad(info, pads, args.front());
        }
        return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
                                    args.front());
......
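As a sanity check of the constant-mode arithmetic (a numpy sketch, not MIGraphX code), the shapes produced by the expanded pads match the tests added below:

import numpy as np

# Mirrors pad_4arg_axes_test: a (1, 3, 4, 5) input, pads [1, 3, 2, 4] on
# axes [1, 3], constant value 1. np.pad takes per-dimension (begin, end)
# pairs, the same information as the expanded pads [0, 1, 0, 3, 0, 2, 0, 4].
x = np.zeros((1, 3, 4, 5), dtype=np.float32)
y = np.pad(x, [(0, 0), (1, 2), (0, 0), (3, 4)], constant_values=1.0)
print(y.shape)  # (1, 6, 4, 12), the expected output shape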
......@@ -5107,6 +5107,32 @@ def pad_test():
return ([node], [x], [y])
@onnx_test()
def pad_asym_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12])
node = onnx.helper.make_node('Pad',
inputs=['0'],
pads=[0, 1, 0, 3, 0, 2, 0, 4],
outputs=['1'])
return ([node], [x], [y])
@onnx_test()
def pad_asym_invalid_pads_error_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12])
node = onnx.helper.make_node('Pad',
inputs=['0'],
pads=[0, 1, 0, 3, 0, 2],
outputs=['1'])
return ([node], [x], [y])
@onnx_test()
def pad_3arg_test():
values = np.array([1])
......@@ -5139,6 +5165,129 @@ def pad_3arg_test():
return ([arg_val, arg_pad, node], [x], [y])
@onnx_test()
def pad_4arg_axes_test():
values = np.array([1])
val_tensor = helper.make_tensor(name='val',
data_type=TensorProto.FLOAT,
dims=values.reshape(()).shape,
vals=values.astype(float))
arg_val = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_val'],
value=val_tensor)
sizes = np.array([1, 3, 2, 4])
pad_tensor = helper.make_tensor(name='pad_size',
data_type=TensorProto.INT32,
dims=sizes.shape,
vals=sizes.astype(int))
arg_pad = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_pad'],
value=pad_tensor)
axes = np.array([1, 3])
axes_tensor = helper.make_tensor(name='pad_axes',
data_type=TensorProto.INT32,
dims=axes.shape,
vals=axes.astype(int))
arg_axes = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_axes'],
value=axes_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12])
node = onnx.helper.make_node(
'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])
return ([arg_axes, arg_val, arg_pad, node], [x], [y])
@onnx_test()
def pad_4arg_invalid_axes_error_test():
values = np.array([1])
val_tensor = helper.make_tensor(name='val',
data_type=TensorProto.FLOAT,
dims=values.reshape(()).shape,
vals=values.astype(float))
arg_val = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_val'],
value=val_tensor)
sizes = np.array([1, 3, 2, 4])
pad_tensor = helper.make_tensor(name='pad_size',
data_type=TensorProto.INT32,
dims=sizes.shape,
vals=sizes.astype(int))
arg_pad = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_pad'],
value=pad_tensor)
axes = np.array([1, 2, 3])
axes_tensor = helper.make_tensor(name='pad_axes',
data_type=TensorProto.INT32,
dims=axes.shape,
vals=axes.astype(int))
arg_axes = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_axes'],
value=axes_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12])
node = onnx.helper.make_node(
'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])
return ([arg_axes, arg_val, arg_pad, node], [x], [y])
@onnx_test()
def pad_4arg_neg_axes_test():
values = np.array([1])
val_tensor = helper.make_tensor(name='val',
data_type=TensorProto.FLOAT,
dims=values.reshape(()).shape,
vals=values.astype(float))
arg_val = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_val'],
value=val_tensor)
sizes = np.array([1, 3, 2, 4])
pad_tensor = helper.make_tensor(name='pad_size',
data_type=TensorProto.INT32,
dims=sizes.shape,
vals=sizes.astype(int))
arg_pad = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_pad'],
value=pad_tensor)
axes = np.array([-3, -1])
axes_tensor = helper.make_tensor(name='pad_axes',
data_type=TensorProto.INT32,
dims=axes.shape,
vals=axes.astype(int))
arg_axes = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_axes'],
value=axes_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12])
node = onnx.helper.make_node(
'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])
return ([arg_axes, arg_val, arg_pad, node], [x], [y])
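# Note: negative axes are normalized by adding the input rank (see
# calculate_pads_with_axes), so for this rank-4 input [-3, -1] resolves
# to [1, 3] and the model is equivalent to pad_4arg_axes_test above.
# Illustrative check: print([a % 4 for a in [-3, -1]])  # -> [1, 3]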
@onnx_test()
def pad_reflect_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
......@@ -5162,6 +5311,39 @@ def pad_reflect_test():
return ([arg_pad, node], [x], [y])
@onnx_test()
def pad_reflect_with_axes_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5])
sizes = np.array([2, 1])
pad_tensor = helper.make_tensor(name='pad_size',
data_type=TensorProto.INT32,
dims=sizes.shape,
vals=sizes.astype(int))
arg_pad = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_pad'],
value=pad_tensor)
axes = np.array([1])
axes_tensor = helper.make_tensor(name='pad_axes',
data_type=TensorProto.INT32,
dims=axes.shape,
vals=axes.astype(int))
arg_axes = onnx.helper.make_node('Constant',
inputs=[],
outputs=['arg_axes'],
value=axes_tensor)
node = onnx.helper.make_node('Pad',
mode='reflect',
inputs=['0', 'arg_pad', 'arg_axes'],
outputs=['1'])
return ([arg_axes, arg_pad, node], [x], [y])
@onnx_test()
def pad_reflect_multiaxis_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3])
......
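A hedged usage sketch for the generated models (the MIGraphX Python bindings expose parse_onnx; the exact printed IR may differ by version):

import migraphx

# The constant-mode model should lower to a single `pad` instruction
# with pads [0, 1, 0, 3, 0, 2, 0, 4].
prog = migraphx.parse_onnx("pad_4arg_axes_test.onnx")
print(prog)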
......@@ -4958,6 +4958,22 @@ TEST_CASE(pad_test)
EXPECT(p == prog);
}
TEST_CASE(pad_asym_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}}), l0);
auto prog = optimize_onnx("pad_asym_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pad_asym_invalid_pads_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_asym_invalid_pads_error_test.onnx"); }));
}
TEST_CASE(pad_3arg_test)
{
migraphx::program p;
......@@ -4974,6 +4990,51 @@ TEST_CASE(pad_3arg_test)
EXPECT(p == prog);
}
TEST_CASE(pad_4arg_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
// axes=[1,3]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 3}});
// constant_value=1
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
// pads=[1,3,2,4]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}});
auto r = mm->add_instruction(
migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_4arg_axes_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pad_4arg_invalid_axes_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_4arg_invalid_axes_error_test.onnx"); }));
}
TEST_CASE(pad_4arg_neg_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
// axes=[-3,-1]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {-3, -1}});
// constant_value=1
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
// pads=[1,3,2,4]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}});
auto r = mm->add_instruction(
migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_4arg_neg_axes_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pad_attr_dyn_test)
{
migraphx::program p;
......@@ -5032,6 +5093,27 @@ TEST_CASE(pad_reflect_test)
EXPECT(p == prog);
}
TEST_CASE(pad_reflect_with_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 1}});
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_with_axes_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(pad_reflect_multiaxis_test)
{
migraphx::program p;
......
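The reflect lowering in pad_reflect_with_axes_test can be cross-checked with numpy (an illustrative sketch, not part of the test suite): reflect-padding the last axis of a (2, 2) input by (2, 1) equals the slice/concat sequence the parser emits.

import numpy as np

x = np.arange(4, dtype=np.float32).reshape(2, 2)
ref = np.pad(x, [(0, 0), (2, 1)], mode='reflect')
# concat(col0, col1, x, col0) reproduces the slices in the test above
lowered = np.concatenate([x[:, 0:1], x[:, 1:2], x, x[:, 0:1]], axis=1)
assert (ref == lowered).all()
print(ref.shape)  # (2, 5)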
(binary diff: new serialized ONNX protobuf for pad_reflect_with_axes_test; content not shown)