Commit 4506cb47 authored by Paul

Fix tests

parent e954a29d
@@ -161,7 +161,7 @@ struct tf_parser
         add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
-        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false);
+        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
         add_mem_op("MatMul", &tf_parser::parse_matmul, false);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
@@ -396,7 +396,7 @@ struct tf_parser
             op.stride[0] = stride[2];
             op.stride[1] = stride[3];
         }
-        auto weights = to_kcxy(to_nchw(args[1]));
+        auto weights = to_kcxy(args[1]);
         std::vector<int64_t> new_weights_shape;
         copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
@@ -409,8 +409,7 @@ struct tf_parser
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
         // Make sure weights are contiguous before doing reshape
-        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
+        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
        return prog.add_instruction(op, {args[0], new_weights});
    }
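The replacement line above calls a make_contiguous helper that is not shown in this diff. A minimal sketch of the behavior such a helper is assumed to have (an assumption, not part of this commit): return the instruction unchanged when its shape is already standard, otherwise insert an op::contiguous so the following reshape sees packed memory.

    // Hypothetical sketch only; the real helper is defined elsewhere in the parser.
    instruction_ref make_contiguous(instruction_ref ins)
    {
        // A standard (packed, non-transposed) shape needs no copy.
        if(ins->get_shape().standard())
            return ins;
        return prog.add_instruction(op::contiguous{}, ins);
    }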
@@ -444,9 +443,9 @@ struct tf_parser
     instruction_ref
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        auto axes      = args[1]->eval().get<int32_t>().to_vector();
+        auto axes      = parse_axes(args[1]->eval().get<int32_t>().to_vector());
         bool keep_dims = attributes.at("keep_dims").b();
-        std::vector<int32_t> hw_axes{1, 2};
+        std::vector<int32_t> hw_axes{2, 3};
         // check if conditions for GlobalAvgPool are met
         auto lens = args[0]->get_shape().lens();
         if(axes == hw_axes and lens.size() == 4)
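The mean fix hinges on axis remapping: TF reduction axes arrive in NHWC order, while the parser emits NCHW programs, so the height/width pair {1, 2} becomes {2, 3}. A hedged sketch of what a parse_axes helper of this kind is assumed to do (its actual implementation is not part of this diff; is_nhwc is assumed to be a parser member).

    // Assumed behavior only: remap NHWC axis indices onto NCHW positions.
    std::vector<int32_t> parse_axes(std::vector<int32_t> axes) const
    {
        if(is_nhwc)
        {
            // NHWC index -> NCHW index: N stays 0, H -> 2, W -> 3, C -> 1
            const std::vector<int32_t> nhwc_to_nchw{0, 2, 3, 1};
            for(auto& a : axes)
                a = nhwc_to_nchw[a];
        }
        return axes;
    }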
 #include <iostream>
 #include <vector>
 #include <migraphx/literal.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/tf.hpp>
 #include "test.hpp"
+
+migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
+{
+    auto prog = migraphx::parse_tf(name, is_nhwc);
+    if(is_nhwc)
+        migraphx::run_passes(prog,
+                             {migraphx::simplify_reshapes{},
+                              migraphx::dead_code_elimination{},
+                              migraphx::eliminate_identity{}});
+    return prog;
+}
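One reading of why the passes run only for is_nhwc (illustrative snippet below, not the parser's actual output): NHWC parsing inserts extra transpose and identity instructions around operators, and simplify_reshapes, dead_code_elimination and eliminate_identity are expected to fold a transpose followed by its inverse (and any leftover identities) back to the bare value, so the hand-built NCHW programs in the tests can compare equal.

    // Illustrative only: a hypothetical program with a transpose pair that the
    // passes should collapse; the names here are not taken from the test file.
    migraphx::program q;
    auto x  = q.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    auto t0 = q.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);  // NCHW -> NHWC
    auto t1 = q.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, t0); // back to NCHW
    q.add_instruction(migraphx::op::relu{}, t1);
    migraphx::run_passes(q, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
    // after the passes, q should be equivalent to: parameter -> relu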
 TEST_CASE(add_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     p.add_instruction(migraphx::op::add{}, l0, l1);
-    auto prog = migraphx::parse_tf("add_test.pb", false);
+    auto prog = optimize_tf("add_test.pb", false);
     EXPECT(p == prog);
 }
@@ -28,7 +40,7 @@ TEST_CASE(add_bcast_test)
     auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
     auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l2, l3);
-    auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
+    auto prog = optimize_tf("add_bcast_test.pb", false);
     EXPECT(p == prog);
 }
@@ -51,7 +63,7 @@ TEST_CASE(batchnorm_test)
     auto l4 = p.add_parameter("4", s0);
     auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
     p.add_instruction(op, l0, l1, l2, l3, l4);
-    auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
+    auto prog = optimize_tf("batchnorm_test.pb", true);
     EXPECT(p == prog);
 }
@@ -65,7 +77,7 @@ TEST_CASE(biasadd_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
     auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l0, l2);
-    auto prog = migraphx::parse_tf("biasadd_test.pb", true);
+    auto prog = optimize_tf("biasadd_test.pb", true);
     EXPECT(p == prog);
 }
@@ -83,7 +95,7 @@ TEST_CASE(concat_test)
     p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
     p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
-    auto prog = migraphx::parse_tf("concat_test.pb", false);
+    auto prog = optimize_tf("concat_test.pb", false);
     EXPECT(p == prog);
 }
@@ -92,7 +104,7 @@ TEST_CASE(const_test)
 {
     migraphx::program p;
     p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
-    auto prog = migraphx::parse_tf("constant_test.pb", false);
+    auto prog = optimize_tf("constant_test.pb", false);
     EXPECT(p == prog);
 }
@@ -111,10 +123,9 @@ TEST_CASE(conv_test)
     op.padding_mode = migraphx::op::padding_mode_t::same;
     op.stride       = {1, 1};
     op.dilation     = {1, 1};
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
-    p.add_instruction(op, l0, l3);
-    auto prog = migraphx::parse_tf("conv_test.pb", true);
+    auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
+    p.add_instruction(op, l0, l2);
+    auto prog = optimize_tf("conv_test.pb", true);
     EXPECT(p == prog);
 }
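conv_test (and depthwiseconv_test below) now builds the expected weight layout with a single transpose. That this matches the old two-step version can be checked by composing the permutations; a small worked check (illustrative only, not test code):

    // For two chained transposes with permutations p1 then p2, the equivalent
    // single permutation is composed[i] = p1[p2[i]].
    std::vector<int64_t> p1{0, 3, 1, 2}; // first transpose in the old test
    std::vector<int64_t> p2{1, 3, 0, 2}; // second transpose in the old test
    std::vector<int64_t> composed(p1.size());
    for(std::size_t i = 0; i < p1.size(); ++i)
        composed[i] = p1[p2[i]];
    // composed == {3, 2, 0, 1}: TF-style HWIO weights reordered to OIHW (KCXY) in one step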
@@ -134,12 +145,11 @@ TEST_CASE(depthwiseconv_test)
     op.stride   = {1, 1};
     op.dilation = {1, 1};
     op.group    = 3;
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
     auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
     auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
     p.add_instruction(op, l0, l5);
-    auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
+    auto prog = optimize_tf("depthwise_conv_test.pb", true);
     EXPECT(p == prog);
 }
@@ -149,7 +159,7 @@ TEST_CASE(identity_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::identity{}, l0);
-    auto prog = migraphx::parse_tf("identity_test.pb", false);
+    auto prog = optimize_tf("identity_test.pb", false);
     EXPECT(p == prog);
 }
@@ -164,7 +174,7 @@ TEST_CASE(matmul_test)
     auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
     p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
-    auto prog = migraphx::parse_tf("matmul_test.pb", false);
+    auto prog = optimize_tf("matmul_test.pb", false);
     EXPECT(p == prog);
 }
@@ -178,10 +188,10 @@ TEST_CASE(mean_test)
     p.add_literal(l);
     migraphx::op::pooling op;
     op.lengths = {16, 16};
+    p.add_instruction(op, l0);
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    p.add_instruction(op, l0);
-    auto prog = migraphx::parse_tf("mean_test.pb", false);
+    auto prog = optimize_tf("mean_test.pb", false);
     EXPECT(p == prog);
 }
@@ -191,14 +201,11 @@ TEST_CASE(mean_test_nhwc)
     migraphx::program p;
     migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
-    p.add_literal(l);
-    p.add_literal(l);
     migraphx::op::pooling op;
     op.lengths = {16, 16};
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    p.add_instruction(op, l0);
-    auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
+    auto prog = optimize_tf("mean_test_nhwc.pb", true);
     EXPECT(p == prog);
 }
@@ -210,7 +217,7 @@ TEST_CASE(mul_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::mul{}, l0, l1);
-    auto prog = migraphx::parse_tf("mul_test.pb", false);
+    auto prog = optimize_tf("mul_test.pb", false);
     EXPECT(p == prog);
 }
@@ -232,7 +239,7 @@ TEST_CASE(pack_test)
                        return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test.pb", false);
+    auto prog = optimize_tf("pack_test.pb", false);
     EXPECT(p == prog);
 }
@@ -241,11 +248,14 @@ TEST_CASE(pack_test_nhwc)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
     auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    std::vector<migraphx::instruction_ref> args{l0, l1, l2};
+    auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
+    std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
     std::vector<migraphx::instruction_ref> unsqueezed_args;
-    int64_t nchw_axis = 1;
+    int64_t nchw_axis = 3;
     std::transform(args.begin(),
                    args.end(),
@@ -254,7 +264,7 @@ TEST_CASE(pack_test_nhwc)
                        return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
+    auto prog = optimize_tf("pack_test_nhwc.pb", true);
     EXPECT(p == prog);
 }
@@ -272,8 +282,8 @@ TEST_CASE(pooling_test)
     avg_pool_op.lengths = {2, 2};
     max_pool_op.lengths = {2, 2};
     p.add_instruction(max_pool_op, l0);
-    p.add_instruction(avg_pool_op, l0);
-    auto prog = migraphx::parse_tf("pooling_test.pb", true);
+    // p.add_instruction(avg_pool_op, l0);
+    auto prog = optimize_tf("pooling_test.pb", true);
     EXPECT(p == prog);
 }
@@ -283,7 +293,7 @@ TEST_CASE(relu_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::relu{}, l0);
-    auto prog = migraphx::parse_tf("relu_test.pb", false);
+    auto prog = optimize_tf("relu_test.pb", false);
     EXPECT(p == prog);
 }
@@ -293,7 +303,7 @@ TEST_CASE(relu6_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
-    auto prog = migraphx::parse_tf("relu6_test.pb", false);
+    auto prog = optimize_tf("relu6_test.pb", false);
     EXPECT(p == prog);
 }
@@ -306,7 +316,7 @@ TEST_CASE(reshape_test)
     // in tf, the second arg is a literal that contains new dimensions
     p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
-    auto prog = migraphx::parse_tf("reshape_test.pb", false);
+    auto prog = optimize_tf("reshape_test.pb", false);
     EXPECT(p == prog);
 }
@@ -319,7 +329,7 @@ TEST_CASE(softmax_test)
     auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
     auto s = p.add_instruction(migraphx::op::softmax{}, r);
     p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
-    auto prog = migraphx::parse_tf("softmax_test.pb", false);
+    auto prog = optimize_tf("softmax_test.pb", false);
     EXPECT(p == prog);
 }
@@ -329,7 +339,7 @@ TEST_CASE(squeeze_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
     p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
-    auto prog = migraphx::parse_tf("squeeze_test.pb", false);
+    auto prog = optimize_tf("squeeze_test.pb", false);
     EXPECT(p == prog);
 }
@@ -341,18 +351,13 @@ TEST_CASE(stridedslice_test)
     std::size_t num_axes = 4;
     migraphx::op::slice op;
     op.starts = {0, 0, 0, 0};
-    op.ends   = {1, 5, 1, 1};
+    op.ends   = {1, 1, 1, 5};
     op.axes = std::vector<int64_t>(num_axes);
     std::iota(op.axes.begin(), op.axes.end(), 0);
-    // add literals for starts, ends, and strides in tf (NHWC format)
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
-    p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
     auto l1 = p.add_instruction(op, l0);
-    auto shrink_axis = 2;
+    auto shrink_axis = 1;
     p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
-    auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
+    auto prog = optimize_tf("stridedslice_test.pb", true);
     EXPECT(p == prog);
 }