Commit f21a7346 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into propogate-constant

parents 68858a5b 369cb3a5
This diff is collapsed.
...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test) ...@@ -15,7 +15,7 @@ TEST_CASE(pytorch_conv_bias_test)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
p.add_instruction(migraphx::op::add{}, l3, l4); p.add_instruction(migraphx::op::add{}, l3, l4);
auto prog = migraphx::parse_onnx("conv.onnx"); auto prog = migraphx::parse_onnx("conv.onnx");
...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool) ...@@ -30,7 +30,7 @@ TEST_CASE(pytorch_conv_relu_maxpool)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool) ...@@ -52,7 +52,7 @@ TEST_CASE(pytorch_conv_bn_relu_maxpool)
auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}}); auto p6 = p.add_parameter("6", {migraphx::shape::float_type, {1}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6); auto l6 = p.add_instruction(migraphx::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraphx::op::relu{}, l6); auto l7 = p.add_instruction(migraphx::op::relu{}, l6);
...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -70,7 +70,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}}); auto l2 = p.add_parameter("2", {migraphx::shape::float_type, {5}});
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1); auto l3 = p.add_instruction(migraphx::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape()}, l2); auto l4 = p.add_instruction(migraphx::op::broadcast{axis, l3->get_shape().lens()}, l2);
auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4); auto l5 = p.add_instruction(migraphx::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraphx::op::relu{}, l5); auto l6 = p.add_instruction(migraphx::op::relu{}, l5);
auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2) ...@@ -78,7 +78,7 @@ TEST_CASE(pytorch_conv_relu_maxpool_x2)
auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraphx::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraphx::shape::float_type, {1}});
auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8); auto l10 = p.add_instruction(migraphx::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape()}, l9); auto l11 = p.add_instruction(migraphx::op::broadcast{axis, l10->get_shape().lens()}, l9);
auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11); auto l12 = p.add_instruction(migraphx::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraphx::op::relu{}, l12); auto l13 = p.add_instruction(migraphx::op::relu{}, l12);
p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13); p.add_instruction(migraphx::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test) ...@@ -108,9 +108,9 @@ TEST_CASE(imagescaler_test)
auto scale_val = p.add_literal(0.5f); auto scale_val = p.add_literal(0.5f);
auto bias_vals = p.add_literal( auto bias_vals = p.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraphx::op::scalar{s.lens()}, scale_val);
auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor); auto img_scaled = p.add_instruction(migraphx::op::mul{}, l0, scaled_tensor);
auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, bias_vals);
p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
auto prog = migraphx::parse_onnx("imagescaler_test.onnx"); auto prog = migraphx::parse_onnx("imagescaler_test.onnx");
...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test) ...@@ -338,7 +338,7 @@ TEST_CASE(add_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_onnx("add_bcast_test.onnx"); auto prog = migraphx::parse_onnx("add_bcast_test.onnx");
...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test) ...@@ -365,7 +365,7 @@ TEST_CASE(sub_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2); p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx"); auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
......
...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast) ...@@ -229,6 +229,36 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::op::broadcast{2, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::broadcast{2, lens}, input);
}
}
TEST_CASE(gather) TEST_CASE(gather)
{ {
{ {
......
2
0 Placeholder*
dtype0*
shape
:
2
1 Placeholder*
dtype0*
shape
:
F
matmul1MatMul01*
T0*
transpose_a(*
transpose_b("
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

mul1Mul01*
T0"
\ No newline at end of file
...@@ -63,7 +63,7 @@ TEST_CASE(biasadd_test) ...@@ -63,7 +63,7 @@ TEST_CASE(biasadd_test)
uint64_t axis = 1; uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0); auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true); auto prog = migraphx::parse_tf("biasadd_test.pb", true);
...@@ -129,6 +129,33 @@ TEST_CASE(identity_test) ...@@ -129,6 +129,33 @@ TEST_CASE(identity_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(matmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(mul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test) TEST_CASE(pack_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -237,4 +264,28 @@ TEST_CASE(squeeze_test) ...@@ -237,4 +264,28 @@ TEST_CASE(squeeze_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment