Unverified Commit 4ab38dde authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Conv weight fix (#692)



* change transpose func

* formatting

* fix tf file

* add tests, change broadcast

* formatting

* revert if statement

* add nonzero axis test

* formatting

* remove test and add test file
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 2d2e2bfe
......@@ -49,8 +49,6 @@ struct broadcast
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_lens, std::move(bcast_strides)};
}
else
......
......@@ -60,9 +60,7 @@ struct tf_parser
instruction_ref to_kcxy(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return ins;
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
}
instruction_ref make_contiguous(instruction_ref ins) const
......@@ -470,7 +468,7 @@ struct tf_parser
op.padding[1] = padding[1];
}
}
return mm->add_instruction(op, {l0, to_kcxy(args[1])});
return mm->add_instruction(op, {l0, weights});
}
instruction_ref parse_depthwiseconv(const std::string&,
......
......@@ -346,10 +346,13 @@ TEST_CASE(broadcast)
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{1, lens},
input);
}
{
......
......@@ -145,6 +145,14 @@ def biasadd_test(g1):
tf.nn.bias_add(g1_input, g2_input, name='bias_add1')
@tf_test
def biasadd_scalar_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32, shape=(1, 1), name='0')
g2_const = tf.constant(1.0, tf.float32, shape=(1, ), name='1')
tf.nn.bias_add(g1_input, g2_const, name='bias_add1')
@tf_test
def cast_test(g1):
with g1.as_default():
......@@ -185,6 +193,23 @@ def conv_test(g1):
tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1')
@tf_test
def conv_nchw_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float32,
shape=(1, 3, 16, 16),
name='0')
g1_weights = tf.constant(value=1.0,
dtype=tf.float32,
shape=(3, 3, 3, 32),
name='1')
tf.nn.conv2d(g1_input,
g1_weights, [1, 1, 1, 1],
"SAME",
data_format='NCHW',
name='conv1')
@tf_test
def depthwiseconv_test(g1):
with g1.as_default():
......@@ -540,10 +565,12 @@ if __name__ == '__main__':
batchnorm_test()
batchnormv3_test()
biasadd_test()
biasadd_scalar_test()
cast_test()
concat_test()
const_test()
conv_test()
conv_nchw_test()
depthwiseconv_test()
expanddims_test()
gather_test()
......
......@@ -196,6 +196,23 @@ TEST_CASE(biasadd_test)
EXPECT(p == prog);
}
TEST_CASE(biasadd_scalar_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 1}};
uint64_t axis = 1;
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}, {0}}, {1.0}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l2);
auto prog = optimize_tf("biasadd_scalar_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
......@@ -239,7 +256,7 @@ TEST_CASE(const_test)
EXPECT(p == prog);
}
TEST_CASE(conv_test)
migraphx::program create_conv()
{
migraphx::program p;
......@@ -258,7 +275,21 @@ TEST_CASE(conv_test)
op.dilation = {1, 1};
auto l2 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
mm->add_instruction(op, l0, l2);
auto prog = optimize_tf("conv_test.pb", true);
return p;
}
TEST_CASE(conv_test)
{
migraphx::program p = create_conv();
auto prog = optimize_tf("conv_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(conv_nchw_test)
{
migraphx::program p = create_conv();
auto prog = optimize_tf("conv_nchw_test.pb", false);
EXPECT(p == prog);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment