Unverified Commit 4ab38dde authored by kahmed10, committed by GitHub

Conv weight fix (#692)



* change transpose func

* formatting

* fix tf file

* add tests, change broadcast

* formatting

* revert if statement

* add nonzero axis test

* formatting

* remove test and add test file
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 2d2e2bfe
@@ -49,8 +49,6 @@ struct broadcast
         if(std::all_of(
                broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
         {
-            if(axis != 0)
-                MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
             return {t, broadcast_lens, std::move(bcast_strides)};
         }
         else
...
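Why the axis guard can simply go (my reading of the change, not stated in the diff): when every entry of broadcast_lens is 1, bcast_strides is all zeros, so every valid index resolves to the same single element and the axis value cannot affect the result. A minimal Python sketch of that stride arithmetic, using a hypothetical offset helper rather than MIGraphX API:

    # Sketch: stride-based addressing for an all-ones broadcast shape.
    def offset(index, strides):
        # flat offset = sum of index * stride over all dimensions
        return sum(i * s for i, s in zip(index, strides))

    strides = (0, 0)                # shape {1, 1} broadcast -> all-zero strides
    print(offset((0, 0), strides))  # 0: every index reads the one stored element,
                                    # so the broadcast axis cannot change the result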
@@ -60,9 +60,7 @@ struct tf_parser
     instruction_ref to_kcxy(instruction_ref ins) const
     {
-        if(should_transpose(ins))
-            return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
-        return ins;
+        return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
     }

     instruction_ref make_contiguous(instruction_ref ins) const
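For context (background knowledge, not part of the diff): TensorFlow stores 2-D convolution filters in HWIO order (height, width, in-channels, out-channels), while MIGraphX convolution expects KCXY (out-channels, in-channels, height, width), and {3, 2, 0, 1} is exactly that permutation. With the should_transpose guard dropped, to_kcxy now applies it unconditionally. The same permutation in NumPy:

    import numpy as np

    w_hwio = np.zeros((3, 3, 3, 32))             # TF filter layout: H, W, I, O
    w_kcxy = np.transpose(w_hwio, (3, 2, 0, 1))  # mirrors op::transpose{{3, 2, 0, 1}}
    print(w_kcxy.shape)                          # (32, 3, 3, 3): K, C, X, Y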
@@ -470,7 +468,7 @@ struct tf_parser
                 op.padding[1] = padding[1];
             }
         }
-        return mm->add_instruction(op, {l0, to_kcxy(args[1])});
+        return mm->add_instruction(op, {l0, weights});
     }

     instruction_ref parse_depthwiseconv(const std::string&,
...
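My reading of this hunk (hedged, since the surrounding parse_conv code is truncated here): with to_kcxy now unconditional, calling to_kcxy(args[1]) again at the return would permute filters that were already converted earlier in the function, so the already-converted weights value is reused instead. A double application of {3, 2, 0, 1} is not a no-op:

    import numpy as np

    w = np.zeros((3, 3, 3, 32))               # HWIO
    once = np.transpose(w, (3, 2, 0, 1))      # (32, 3, 3, 3): correct KCXY
    twice = np.transpose(once, (3, 2, 0, 1))  # (3, 3, 32, 3): scrambled layout
    print(once.shape, twice.shape)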
@@ -346,10 +346,13 @@ TEST_CASE(broadcast)
                      migraphx::op::broadcast{0, lens},
                      input);
     }

     {
         std::vector<std::size_t> lens{1, 1};
-        migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
-        throws_shape(migraphx::op::broadcast{1, lens}, input);
+        migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
+                     migraphx::op::broadcast{1, lens},
+                     input);
     }
     {
...
@@ -145,6 +145,14 @@ def biasadd_test(g1):
         tf.nn.bias_add(g1_input, g2_input, name='bias_add1')


+@tf_test
+def biasadd_scalar_test(g1):
+    with g1.as_default():
+        g1_input = tf.compat.v1.placeholder(tf.float32, shape=(1, 1), name='0')
+        g2_const = tf.constant(1.0, tf.float32, shape=(1, ), name='1')
+        tf.nn.bias_add(g1_input, g2_const, name='bias_add1')
+
+
 @tf_test
 def cast_test(g1):
     with g1.as_default():
@@ -185,6 +193,23 @@ def conv_test(g1):
         tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1')


+@tf_test
+def conv_nchw_test(g1):
+    with g1.as_default():
+        g1_input = tf.compat.v1.placeholder(tf.float32,
+                                            shape=(1, 3, 16, 16),
+                                            name='0')
+        g1_weights = tf.constant(value=1.0,
+                                 dtype=tf.float32,
+                                 shape=(3, 3, 3, 32),
+                                 name='1')
+        tf.nn.conv2d(g1_input,
+                     g1_weights, [1, 1, 1, 1],
+                     "SAME",
+                     data_format='NCHW',
+                     name='conv1')
+
+
 @tf_test
 def depthwiseconv_test(g1):
     with g1.as_default():
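One assumption worth making explicit (from the tf.nn.conv2d contract, not the diff): the filter argument stays in HWIO order even when data_format='NCHW', so both conv tests feed identically shaped weights and only the input layout differs. The relationship between the two input layouts, in NumPy:

    import numpy as np

    x_nhwc = np.zeros((1, 16, 16, 3))            # conv_test input: N, H, W, C
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))  # conv_nchw_test input: N, C, H, W
    print(x_nchw.shape)                          # (1, 3, 16, 16)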
@@ -540,10 +565,12 @@ if __name__ == '__main__':
     batchnorm_test()
     batchnormv3_test()
     biasadd_test()
+    biasadd_scalar_test()
     cast_test()
     concat_test()
     const_test()
     conv_test()
+    conv_nchw_test()
     depthwiseconv_test()
     expanddims_test()
     gather_test()
...
@@ -196,6 +196,23 @@ TEST_CASE(biasadd_test)
     EXPECT(p == prog);
 }

+TEST_CASE(biasadd_scalar_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s0{migraphx::shape::float_type, {1, 1}};
+    uint64_t axis = 1;
+    auto l0 = mm->add_parameter("0", s0);
+    auto l1 = mm->add_literal(
+        migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
+    auto l2 = mm->add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
+    mm->add_instruction(migraphx::op::add{}, l0, l2);
+    auto prog = optimize_tf("biasadd_scalar_test.pb", true);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(cast_test)
 {
     migraphx::program p;
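A NumPy analogue of the program this test expects, a broadcast of the shape-{1} literal along axis 1 followed by an elementwise add (the input values are illustrative; the 1.0 bias matches the literal above):

    import numpy as np

    x = np.ones((1, 1), dtype=np.float32)     # parameter "0"
    bias = np.array([1.0], dtype=np.float32)  # shape {1} literal, stride-0 broadcast
    print(x + np.broadcast_to(bias, (1, 1)))  # [[2.]]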
@@ -239,7 +256,7 @@ TEST_CASE(const_test)
     EXPECT(p == prog);
 }

-TEST_CASE(conv_test)
+migraphx::program create_conv()
 {
     migraphx::program p;
@@ -258,7 +275,21 @@ TEST_CASE(conv_test)
     op.dilation = {1, 1};
     auto l2 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
     mm->add_instruction(op, l0, l2);
-    auto prog = optimize_tf("conv_test.pb", true);
+    return p;
+}
+
+TEST_CASE(conv_test)
+{
+    migraphx::program p = create_conv();
+    auto prog = optimize_tf("conv_test.pb", true);
+    EXPECT(p == prog);
+}
+
+TEST_CASE(conv_nchw_test)
+{
+    migraphx::program p = create_conv();
+    auto prog = optimize_tf("conv_nchw_test.pb", false);
     EXPECT(p == prog);
 }
...