Unverified Commit 9879574a authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add tile op (#513)



* fix pad calc

* add tile op

* formatting

* add test

* formatting

* fix tidy

* formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 2074d756
...@@ -131,6 +131,7 @@ struct onnx_parser ...@@ -131,6 +131,7 @@ struct onnx_parser
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>); add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Split", &onnx_parser::parse_split); add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("LSTM", &onnx_parser::parse_lstm); add_mem_op("LSTM", &onnx_parser::parse_lstm);
...@@ -1885,6 +1886,26 @@ struct onnx_parser ...@@ -1885,6 +1886,26 @@ struct onnx_parser
return prog.add_instruction(op::add{}, l_mul, unsq_off_val); return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
} }
instruction_ref
parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
std::vector<std::int64_t> repeats;
arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });
auto l0 = args[0];
for(int i = 0; i < repeats.size(); i++)
{
auto l1 = l0;
for(int j = 1; j < repeats[i]; j++)
{
l0 = prog.add_instruction(op::concat{i}, l0, l1);
}
}
return l0;
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
......
...@@ -2150,6 +2150,30 @@ def tanh_test(): ...@@ -2150,6 +2150,30 @@ def tanh_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def tile_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2])
y = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 4])
node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
return ([node], [x, y], [z],
[helper.make_tensor('y', TensorProto.INT64, [2], [1, 2])])
@onnx_test
def tile_test_3x2():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2])
y = helper.make_tensor_value_info('y', TensorProto.INT64, [2])
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [6, 4])
node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
return ([node], [x, y], [z],
[helper.make_tensor('y', TensorProto.INT64, [2], [3, 2])])
@onnx_test @onnx_test
def transpose_test(): def transpose_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 2, 3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 2, 3])
......
...@@ -1628,6 +1628,32 @@ TEST_CASE(tanh_test) ...@@ -1628,6 +1628,32 @@ TEST_CASE(tanh_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(tile_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {1, 2}});
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
p.add_instruction(migraphx::op::concat{1}, input, input);
auto prog = optimize_onnx("tile_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(tile_test_3x2)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {3, 2}});
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l0 = p.add_instruction(migraphx::op::concat{0}, input, input);
auto l1 = p.add_instruction(migraphx::op::concat{0}, l0, input);
p.add_instruction(migraphx::op::concat{1}, l1, l1);
auto prog = optimize_onnx("tile_test_3x2.onnx");
EXPECT(p == prog);
}
TEST_CASE(transpose_test) TEST_CASE(transpose_test)
{ {
migraphx::program p; migraphx::program p;
......
 tile_test:d

x
yz"Tile tile_test* :ByZ
x


Z
y

b
z


B
 tile_test_3x2:h

x
yz"Tile tile_test_3x2* :ByZ
x


Z
y

b
z


B
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment