Unverified Commit 63036c34 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #256 from ROCmSoftwarePlatform/depthwise_conv

Depthwise convolution
parents 767ca0cc 71efa874
......@@ -117,6 +117,7 @@ struct tf_parser
add_mem_op("ConcatV2", &tf_parser::parse_concat);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
......@@ -339,6 +340,62 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights});
}
instruction_ref parse_depthwiseconv(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
// weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
// out_channels is equal to the multiplier. Adjust by inserting a reshape and
// setting in_channels to 1
int64_t multiplier = new_weights_shape[0];
int64_t out_channels = num_channels * multiplier;
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, weights);
return prog.add_instruction(op, {args[0], new_weights});
}
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......
......@@ -119,6 +119,30 @@ TEST_CASE(conv_test)
EXPECT(p == prog);
}
TEST_CASE(depthwiseconv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
std::vector<float> weight_data(3 * 3 * 3 * 1);
std::fill(weight_data.begin(), weight_data.end(), 1.0f);
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l4 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l3);
p.add_instruction(op, l0, l4);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(identity_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment