Commit 5f52c066 authored by Khalique's avatar Khalique
Browse files

add test case for depthwise_conv

parent e4f95696
......@@ -369,7 +369,6 @@ struct tf_parser
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
......@@ -381,8 +380,13 @@ struct tf_parser
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
// weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
// out_channels is equal to the multiplier. Adjust by inserting a reshape and
// setting in_channels to 1
int64_t multiplier = new_weights_shape[0];
int64_t out_channels = num_channels * multiplier;
new_weights_shape[0] = out_channels;
......
......@@ -119,6 +119,30 @@ TEST_CASE(conv_test)
EXPECT(p == prog);
}
TEST_CASE(depthwiseconv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
std::vector<float> weight_data(3 * 3 * 3 * 1);
std::fill(weight_data.begin(), weight_data.end(), 1.0f);
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l4 = p.add_instruction(migraphx::op::reshape{{3,1,3,3}}, l3);
p.add_instruction(op, l0, l4);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(identity_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment