Commit ab855464 authored by Khalique's avatar Khalique
Browse files

added pack op

parent dc85aa6b
......@@ -119,6 +119,7 @@ struct tf_parser
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax);
......@@ -353,6 +354,19 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_pack(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args)
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
std::transform(args.begin(), args.end(), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg){ return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......
.
0 Placeholder*
shape:*
dtype0
.
1 Placeholder*
dtype0*
shape:
.
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
......@@ -129,6 +129,24 @@ TEST_CASE(identity_test)
EXPECT(p == prog);
}
TEST_CASE(pack_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 1;
std::transform(args.begin(), args.end(), std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) { return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg); });
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pooling_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment