"vscode:/vscode.git/clone" did not exist on "eb370577755e181c0b9a5dbd374633f2cdb21123"
Unverified Commit d1e945da authored by kahmed10, committed by GitHub

Add same padding mode for onnx (#456)



* fix pad calc

* add padding calc and test

* formatting

* made asym generic function

* formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 63d8e40a
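
For reference, ONNX auto_pad=SAME* asks the importer to choose pads so that each output spatial extent equals ceil(input / stride); when the total pad for a dimension is odd, the two sides differ by one, which is why this commit routes asymmetric pads through an explicit pad op. Below is a minimal standalone sketch of that arithmetic. The names same_pad_total, pad_begin, and pad_end are illustrative only; the helper the commit actually calls is calculate_padding from migraphx/pad_calc.hpp.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Total SAME padding for one spatial dimension: pick pads so that
    // output extent == ceil(input / stride), per the ONNX auto_pad spec.
    int64_t same_pad_total(int64_t input, int64_t stride, int64_t dilation, int64_t kernel)
    {
        int64_t output           = (input + stride - 1) / stride; // ceil(input / stride)
        int64_t effective_kernel = (kernel - 1) * dilation + 1;   // dilated kernel extent
        return std::max<int64_t>((output - 1) * stride + effective_kernel - input, 0);
    }

    int main()
    {
        // Dimensions from conv_autopad_same_test below: 32-wide input, 3-wide
        // kernel, stride 1, dilation 1 -> total pad 2, split 1 + 1 (symmetric).
        int64_t total     = same_pad_total(32, 1, 1, 3);
        int64_t pad_begin = total / 2;         // any extra pad goes on the end side
        int64_t pad_end   = total - pad_begin;
        std::cout << pad_begin << ", " << pad_end << '\n'; // prints "1, 1"
    }
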
@@ -16,6 +16,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/onnx.hpp>
+#include <migraphx/pad_calc.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -302,6 +303,24 @@ struct onnx_parser
         return curr_ins;
     }
 
+    template <class Op>
+    void check_asym_padding(instruction_ref& ins,
+                            std::vector<int64_t>& padding,
+                            Op& op,
+                            float pad_val = 0)
+    {
+        if(padding[0] != padding[2] || padding[1] != padding[3])
+        {
+            padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
+            ins     = prog.add_instruction(op::pad{padding, pad_val}, ins);
+        }
+        else
+        {
+            op.padding[0] = padding[0];
+            op.padding[1] = padding[1];
+        }
+    }
+
     instruction_ref parse_clip(const std::string&,
                                const attribute_map& attributes,
                                std::vector<instruction_ref> args)
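
A note on the new helper: the convolution and pooling operators carry only one pad value per spatial dimension, so symmetric ONNX pads {top, left, bottom, right} can be folded directly into op.padding. Asymmetric pads cannot, so check_asym_padding instead prepends an explicit op::pad instruction, expanding the four ONNX values into the eight begin/end pads of a 4-D NCHW tensor. For example, the hypothetical asymmetric pads {0, 0, 1, 1} (pad only bottom and right) become {0, 0, 0, 0, 0, 0, 1, 1}. The pad_val parameter lets each caller choose the fill value for the padded region.
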
@@ -424,7 +443,8 @@ struct onnx_parser
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         Op op;
         auto l0      = args[0];
+        auto weights = args[1];
         if(contains(attributes, "pads"))
         {
             if(contains(attributes, "auto_pad"))
@@ -441,17 +461,7 @@ struct onnx_parser
             {
                 MIGRAPHX_THROW("padding should have 4 values");
             }
-            if(padding[0] != padding[2] || padding[1] != padding[3])
-            {
-                // insert zeros for pad op (args[0] has 4 dims)
-                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-                l0 = prog.add_instruction(op::pad{padding}, l0);
-            }
-            else
-            {
-                op.padding[0] = padding[0];
-                op.padding[1] = padding[1];
-            }
+            check_asym_padding(l0, padding, op);
         }
         if(contains(attributes, "strides"))
         {
@@ -471,7 +481,19 @@ struct onnx_parser
             if(s.find("SAME") != std::string::npos)
             {
                 op.padding_mode = op::padding_mode_t::same;
+                std::vector<size_t> weight_dims = weights->get_shape().lens();
+                size_t weight_h = weight_dims[2];
+                size_t weight_w = weight_dims[3];
+
+                auto input_dims = l0->get_shape().lens();
+                std::vector<int64_t> padding(input_dims.size());
+                calculate_padding(
+                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
+                calculate_padding(
+                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);
+
+                check_asym_padding(l0, padding, op);
             }
         }
 
         if(contains(attributes, "group"))
@@ -618,27 +640,10 @@ struct onnx_parser
             {
                 MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
             }
-            if(padding[0] != padding[2] || padding[1] != padding[3])
-            {
-                // insert zeros for pad op (args[0] has 4 dims)
-                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-                // MaxPool
-                if(op.mode == "max")
-                {
-                    l0 = prog.add_instruction(
-                        op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
-                }
-                // AveragePool
-                else
-                {
-                    l0 = prog.add_instruction(op::pad{padding}, l0);
-                }
-            }
-            else
-            {
-                op.padding[0] = padding[0];
-                op.padding[1] = padding[1];
-            }
+            float pad_val = 0;
+            if(op.mode == "max")
+                pad_val = std::numeric_limits<float>::lowest();
+            check_asym_padding(l0, padding, op, pad_val);
         }
         if(contains(attributes, "strides"))
         {
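
The pooling change preserves the old behavior through the shared helper: average pooling pads with zeros, while max pooling pads with std::numeric_limits<float>::lowest() so that padded elements can never win the max (max(lowest, x) == x for any finite x).
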
(Binary file added: conv_autopad_same_test.onnx — the serialized ONNX protobuf for the new test, containing a single Conv node with auto_pad "SAME", dilations [1, 1], and strides [1, 1] over inputs "0" and "1", producing output "2".)
@@ -492,6 +492,22 @@ def conv_autopad_fail_test():
     return ([node], [x, y], [out])
 
 
+@onnx_test
+def conv_autopad_same_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 3, 3])
+    out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 32, 32])
+
+    node = onnx.helper.make_node('Conv',
+                                 inputs=['0', '1'],
+                                 outputs=['2'],
+                                 dilations=[1, 1],
+                                 strides=[1, 1],
+                                 auto_pad='SAME')
+
+    return ([node], [x, y], [out])
+
+
 @onnx_test
 def conv_bias_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 32, 32])
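
The @onnx_test decorator used above (defined earlier in this generator script) is what turns the function into a saved model; presumably rerunning the generator is what produced the conv_autopad_same_test.onnx binary shown earlier.
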
@@ -341,6 +341,20 @@ TEST_CASE(conv_autopad_fail_test)
     EXPECT(test::throws([&] { optimize_onnx("conv_autopad_fail_test.onnx"); }));
 }
 
+TEST_CASE(conv_autopad_same_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
+    auto l1 = p.add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
+    migraphx::op::convolution op;
+    op.padding      = {1, 1};
+    op.padding_mode = migraphx::op::padding_mode_t::same;
+    p.add_instruction(op, l0, l1);
+
+    auto prog = optimize_onnx("conv_autopad_same_test.onnx");
+
+    EXPECT(p == prog);
+}
 
 TEST_CASE(conv_bias_test)
 {
     migraphx::program p;
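
The expected program mirrors the padding arithmetic: for this 32x32 input under a 3x3 kernel at stride 1, the computed SAME pads are symmetric (1 per side), so they fold into op.padding = {1, 1} on the convolution itself and no separate pad instruction appears in the parsed program.
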