"ts/vscode:/vscode.git/clone" did not exist on "3ec26b40afd88b8255ebc74b46f475b77aa4d19b"
Commit 86e8a8e6 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into parse_mean_fix

parents 6c27b962 b93f5320
@@ -108,6 +108,7 @@ struct tf_parser
    {
        add_generic_op("Identity", op::identity{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Relu6", op::clip{6.0, 0.0});
        add_binary_op("Add", op::add{});
        add_binary_op("Mul", op::mul{});
@@ -117,6 +118,7 @@ struct tf_parser
        add_mem_op("ConcatV2", &tf_parser::parse_concat);
        add_mem_op("Const", &tf_parser::parse_constant);
        add_mem_op("Conv2D", &tf_parser::parse_conv);
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
        add_mem_op("MatMul", &tf_parser::parse_matmul);
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
@@ -339,6 +341,62 @@ struct tf_parser
        return prog.add_instruction(op, {args[0], weights});
    }
    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        op::convolution op;
        size_t num_channels = args[0]->get_shape().lens()[1];
        op.group            = num_channels;
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        auto weights = args[1];
        // check if weights are from a constant
        if(weights->name() != "@param")
        {
            if(is_nhwc)
            {
                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
            }
            else
            {
                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
            }
        }
        std::vector<int64_t> new_weights_shape;
        copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = new_weights_shape[0];
        int64_t out_channels = num_channels * multiplier;
        new_weights_shape[0] = out_channels;
        new_weights_shape[1] = 1;
        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, weights);
        return prog.add_instruction(op, {args[0], new_weights});
    }
    instruction_ref
    parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
...
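The weight handling in parse_depthwiseconv above comes down to one piece of shape arithmetic: TensorFlow stores depthwise weights as (h, w, in_channels, multiplier), the transposes bring them to (multiplier, in_channels, h, w), and the reshape folds the multiplier into the output channels as (in_channels * multiplier, 1, h, w), so that a grouped convolution with group == in_channels reproduces the depthwise result. A minimal standalone sketch of that arithmetic, using the 3x3x3x1 weights from the test below (not part of the commit; variable names are illustrative):

    #include <array>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        // TF depthwise weight layout: (h, w, in_channels, multiplier)
        std::array<std::int64_t, 4> tf_shape{3, 3, 3, 1};
        std::int64_t in_channels = tf_shape[2];
        std::int64_t multiplier  = tf_shape[3];
        // After the parser's transposes: (multiplier, in_channels, h, w)
        std::array<std::int64_t, 4> transposed{multiplier, in_channels, tf_shape[0], tf_shape[1]};
        // The reshape folds the multiplier into out_channels and sets
        // in_channels to 1, matching a grouped conv with group == in_channels
        std::array<std::int64_t, 4> grouped{in_channels * multiplier, 1, transposed[2], transposed[3]};
        std::cout << grouped[0] << ", " << grouped[1] << ", "
                  << grouped[2] << ", " << grouped[3] << "\n"; // prints 3, 1, 3, 3
    }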
(new binary file relu6_test.pb: a serialized TensorFlow graph containing a float Placeholder input feeding a Relu6 node; the raw protobuf bytes are not readable in the diff. No newline at end of file.)
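For context on the new operator mapping: Relu6 is the elementwise function min(max(x, 0), 6), so it needs no dedicated operator; judging by the test below, op::clip{6.0, 0.0} takes the max bound first and the min bound second. A one-line reference implementation (not part of the commit):

    #include <algorithm>

    // Relu6(x) = min(max(x, 0), 6), i.e. clip to the range [0, 6]
    float relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }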
@@ -119,6 +119,30 @@ TEST_CASE(conv_test)
    EXPECT(p == prog);
}
TEST_CASE(depthwiseconv_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    std::vector<float> weight_data(3 * 3 * 3 * 1);
    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
    auto l1 =
        p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
    migraphx::op::convolution op;
    op.padding_mode = migraphx::op::padding_mode_t::same;
    op.stride       = {1, 1};
    op.dilation     = {1, 1};
    op.group        = 3;
    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
    auto l4 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l3);
    p.add_instruction(op, l0, l4);
    auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
    EXPECT(p == prog);
}
TEST_CASE(identity_test)
{
    migraphx::program p;
@@ -229,6 +253,16 @@ TEST_CASE(relu_test)
    EXPECT(p == prog);
}
TEST_CASE(relu6_test)
{
    migraphx::program p;
    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
    auto prog = migraphx::parse_tf("relu6_test.pb", false);
    EXPECT(p == prog);
}
TEST_CASE(reshape_test)
{
    migraphx::program p;
...