Commit 985d5b1a authored by Khalique's avatar Khalique
Browse files

added nhwc case, addressed axis bounds checking

parent af00eea8
......@@ -363,11 +363,21 @@ struct tf_parser
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
std::transform(
args.begin(),
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("Error in protobuf: axis must be smaller than input size");
}
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
[&](instruction_ref arg) {
return prog.add_instruction(op::unsqueeze{{axis}}, arg);
});
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
......
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
:
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
......@@ -151,6 +151,28 @@ TEST_CASE(pack_test)
EXPECT(p == prog);
}
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1,2,1,1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1,2,1,1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1,2,1,1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
TEST_CASE(pooling_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment