Unverified Commit 6207ea00 authored by kahmed10, committed by GitHub

Add reflect pad mode (#515)



* add reflect pad and tests

* formatting

* modify iterators

* formatting

* test tidy error

* fix algorithm

* formatting

* rename function

* move conditional

* formatting

* fix test

* fix left pad and tests
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent dc61b9ae
@@ -452,6 +452,82 @@ struct onnx_parser
        return ins;
    }
    void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
    {
        int k         = 0;
        bool reversed = false;
        // in reflect padding, if the number of pads exceeds the dimension size,
        // compute the extra pad indices periodically, e.g. (1, 2, 3, 2, 1, 0)
        for(int& idx : indices)
        {
            if(k == num_dims - 1)
                reversed = true;
            if(k == 0)
                reversed = false;
            if(reversed)
                k--;
            else
                k++;
            idx = k;
        }
    }
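
For reference, here is a minimal Python sketch (mine, not part of the commit) of the same triangle-wave index sequence. Note that the `num_dims` parameter above is the length of the axis being padded, which `dim_size` stands in for here:

```python
# Sketch of calc_reflect_indices: indices climb to dim_size - 1, then
# descend back to 0, repeating, so pads wider than the axis wrap around
# periodically rather than reading out of bounds.
def reflect_indices(count, dim_size):
    indices, k, going_down = [], 0, False
    for _ in range(count):
        if k == dim_size - 1:
            going_down = True
        if k == 0:
            going_down = False
        k = k - 1 if going_down else k + 1
        indices.append(k)
    return indices

print(reflect_indices(6, 4))  # [1, 2, 3, 2, 1, 0]
```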
    instruction_ref reflect_pad(const std::vector<int64_t>& pads, instruction_ref input)
    {
        size_t num_dims = pads.size() / 2;
        std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
        std::vector<int> rdims(pads.begin() + num_dims, pads.end());
        assert(ldims.size() == rdims.size());
        std::vector<int64_t> axes(num_dims);
        std::iota(axes.begin(), axes.end(), int64_t{0});
        // iterate over dimensions, starting from the innermost (last) dimension
        for(int64_t i = num_dims - 1; i >= 0; i--)
        {
            auto axis   = i;
            auto lcount = ldims.at(i);
            auto rcount = rdims.at(i);
            if(lcount == 0 and rcount == 0) // no padding for current dim
                continue;
            // recalculate starts and ends on each iteration, since the shape may change
            std::vector<size_t> dims = input->get_shape().lens();
            std::vector<int64_t> starts(axes.size(), 0);
            std::vector<int64_t> ends(dims.begin(), dims.end());
            std::vector<instruction_ref> slices;
            auto starts_it = starts.begin() + i;
            auto ends_it   = ends.begin() + i;
            auto dims_it   = dims.begin() + i;
            std::vector<int> l_indices(lcount);
            std::vector<int> r_indices(rcount);
            // compute slice indices in a periodic fashion
            calc_reflect_indices(l_indices, *dims_it);
            calc_reflect_indices(r_indices, *dims_it);
            for(int idx : l_indices)
            {
                *starts_it = idx;
                *ends_it   = *starts_it + 1;
                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
            }
            // when padding on the left side, the outermost pad should be at the beginning
            std::reverse(slices.begin(), slices.end());
            slices.push_back(input);
            for(int idx : r_indices)
            {
                *starts_it = *dims_it - idx - 1;
                *ends_it   = *starts_it + 1;
                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
            }
            input = prog.add_instruction(op::concat{axis}, slices);
        }
        return input;
    }
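
As a sanity check on the intended semantics (my own example, not from the commit): reflect padding mirrors the tensor about its edges without repeating the edge element, matching NumPy's reflect mode:

```python
import numpy as np

x = np.array([[1., 2.],
              [3., 4.]])
# Pad axis 1 by 2 on the left and 1 on the right, as in pad_reflect_test below.
print(np.pad(x, ((0, 0), (2, 1)), mode='reflect'))
# [[1. 2. 1. 2. 1.]
#  [3. 4. 3. 4. 3.]]
```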
    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
@@ -1130,6 +1206,18 @@ struct onnx_parser
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode == "reflect")
                return reflect_pad(pads, args.front());
            if(mode != "constant")
            {
                MIGRAPHX_THROW(
                    "PARSE_PAD: migraphx currently only supports constant and reflect padding");
            }
        }
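
For context, the ONNX Pad operator defines three modes: constant, reflect, and edge. After this change the parser handles the first two and still rejects edge. A hypothetical node (reusing the input names from the tests below) that would hit the throw above:

```python
import onnx

# Hypothetical example: an 'edge'-mode Pad node. Parsing a model containing
# this node would raise the MIGRAPHX_THROW above, since only 'constant' and
# 'reflect' are handled.
node = onnx.helper.make_node('Pad',
                             mode='edge',
                             inputs=['0', 'arg_pad'],
                             outputs=['1'])
```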
        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
@@ -1151,14 +1239,6 @@ struct onnx_parser
            value = parse_value(info.attributes.at("value")).at<float>();
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
            {
                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
            }
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the shape since, output of
...
@@ -1476,6 +1476,52 @@ def pad_3arg_test():
    return ([arg_val, arg_pad, node], [x], [y])

@onnx_test
def pad_reflect_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5])
    sizes = np.array([0, 2, 0, 1])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=sizes.shape,
                                    vals=sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)
    node = onnx.helper.make_node('Pad',
                                 mode='reflect',
                                 inputs=['0', 'arg_pad'],
                                 outputs=['1'])

    return ([arg_pad, node], [x], [y])

@onnx_test
def pad_reflect_multiaxis_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 5])
    sizes = np.array([0, 2, 2, 0])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=sizes.shape,
                                    vals=sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)
    node = onnx.helper.make_node('Pad',
                                 mode='reflect',
                                 inputs=['0', 'arg_pad'],
                                 outputs=['1'])

    return ([arg_pad, node], [x], [y])
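
For reference, NumPy reproduces the expected output of the multiaxis graph (my own check; ONNX orders pads as [x1_begin, x2_begin, ..., x1_end, x2_end], so [0, 2, 2, 0] means 2 before axis 1 and 2 after axis 0):

```python
import numpy as np

y = np.array([[1., 2., 3.],
              [4., 5., 6.]])
print(np.pad(y, ((0, 2), (2, 0)), mode='reflect'))
# [[3. 2. 1. 2. 3.]
#  [6. 5. 4. 5. 6.]
#  [3. 2. 1. 2. 3.]
#  [6. 5. 4. 5. 6.]]
```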

@onnx_test
def pow_test():
    arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
...
@@ -1168,6 +1168,40 @@ TEST_CASE(pad_3arg_test)
    EXPECT(p == prog);
}

TEST_CASE(pad_reflect_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
    p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
    auto l1 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
    auto l2 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
    auto l3 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {2, 1}}, l0);
    auto r  = p.add_instruction(migraphx::op::concat{1}, l2, l1, l0, l3);
    p.add_return({r});

    auto prog = migraphx::parse_onnx("pad_reflect_test.onnx");

    EXPECT(p == prog);
}

TEST_CASE(pad_reflect_multiaxis_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
    p.add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 2, 0}});
    auto l1 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 1}, {2, 2}}, l0);
    auto l2 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 2}, {2, 3}}, l0);
    auto l3 = p.add_instruction(migraphx::op::concat{1}, l2, l1, l0);
    auto l4 = p.add_instruction(migraphx::op::slice{{0, 1}, {0, 0}, {1, 5}}, l3);
    auto l5 = p.add_instruction(migraphx::op::slice{{0, 1}, {1, 0}, {2, 5}}, l3);
    auto r  = p.add_instruction(migraphx::op::concat{0}, l3, l4, l5);
    p.add_return({r});

    auto prog = migraphx::parse_onnx("pad_reflect_multiaxis_test.onnx");

    EXPECT(p == prog);
}
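
The expected programs above spell out the decomposition as slices plus concats. A NumPy sketch (my own, mirroring pad_reflect_multiaxis_test) of the same two-step construction:

```python
import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
# Innermost axis first: prepend columns 2 and 1 (outermost pad first).
step1 = np.concatenate([x[:, 2:3], x[:, 1:2], x], axis=1)
# Then axis 0: append rows 0 and 1 of the already-padded result.
step2 = np.concatenate([step1, step1[0:1, :], step1[1:2, :]], axis=0)
print(step2)  # same as np.pad(x, ((0, 2), (2, 0)), mode='reflect')
```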

TEST_CASE(pow_test)
{
    migraphx::program p;
...