"docs/en/vscode:/vscode.git/clone" did not exist on "b091e4d29135b01b8a36c935304f0670628adffd"
Commit 850ff8a1 authored by charlie's avatar charlie
Browse files

Fix onnx tests

parent ca9019e2
...@@ -2074,18 +2074,16 @@ def gemm_dyn_inner_test(): ...@@ -2074,18 +2074,16 @@ def gemm_dyn_inner_test():
def gemm_dyn_outer_test(): def gemm_dyn_outer_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, None]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 11]) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 11])
node = onnx.helper.make_node('Gemm', node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'], inputs=['A', 'B'],
outputs=['Y'], outputs=['Y'],
alpha=2.0, alpha=2.0,
beta=2.0,
transA=1, transA=1,
transB=1) transB=1)
return ([node], [A, B, C], [Y]) return ([node], [A, B], [Y])
@onnx_test @onnx_test
...@@ -2102,7 +2100,7 @@ def gemm_dyn_C_error(): ...@@ -2102,7 +2100,7 @@ def gemm_dyn_C_error():
beta=1.0, beta=1.0,
transA=1) transA=1)
return ([node], [A, B], [Y]) return ([node], [A, B, C], [Y])
@onnx_test @onnx_test
......
...@@ -2101,9 +2101,9 @@ TEST_CASE(gemm_half_test) ...@@ -2101,9 +2101,9 @@ TEST_CASE(gemm_half_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {8, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {8, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {6, 1}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
...@@ -2155,22 +2155,13 @@ TEST_CASE(gemm_dyn_outer_test) ...@@ -2155,22 +2155,13 @@ TEST_CASE(gemm_dyn_outer_test)
auto l0 = mm->add_parameter( auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}}); "A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f; auto alpha = 2.f;
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); mm->add_return({dot});
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
mm->add_return({ret});
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7}; options.default_dyn_dim_value = {5, 10, 7};
...@@ -6031,24 +6022,6 @@ TEST_CASE(transpose_test) ...@@ -6031,24 +6022,6 @@ TEST_CASE(transpose_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(topk_attrk_test) TEST_CASE(topk_attrk_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment