"docs/en_US/Tutorial/Installation.md" did not exist on "7d69e3d57c0fdc4d79f1ca1cf47162797baee955"
Commit 4d8a0e98 authored by charlie's avatar charlie
Browse files

Many added tests and changes

parent f6c31887
...@@ -39,6 +39,15 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -39,6 +39,15 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto A = args[0];
auto B = args[1];
if(A->get_shape().ndim() != 2 or B->get_shape().ndim() != 2)
{
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(A->get_shape().ndim()) + "B is rank " +
std::to_string(B->get_shape().ndim()));
}
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
bool transa = false; bool transa = false;
...@@ -60,54 +69,59 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -60,54 +69,59 @@ struct parse_gemm : op_parser<parse_gemm>
transb = parser.parse_value(info.attributes.at("transB")).at<bool>(); transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm(args[0]->get_shape().lens().size()); std::vector<int64_t> perm(2);
std::iota(perm.begin(), perm.end(), int64_t{0}); std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements // swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0]; auto dot_type = A->get_shape().type();
auto dot_type = l1->get_shape().type();
if(alpha != 1.0f) if(alpha != 1.0f)
{ {
auto alpha_literal = info.add_literal(alpha); auto alpha_literal = info.add_literal(alpha);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1); A = info.add_broadcastable_binary_op("mul", alpha_literal, A);
if(l1->get_shape().type() != dot_type)
if(A->get_shape().type() != dot_type)
{ {
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1); A = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), A);
} }
} }
l1 = A = (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), A) : A;
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1; B = (transb) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
auto l2 = (transb) : args[1];
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2); auto ret = info.add_instruction(make_op("dot"), A, B);
if(args.size() == 3) if(args.size() == 3)
{ {
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0) // TODO: support dynamic C input
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
{
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes");
}
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{ {
auto out_lens = l1->get_shape().lens(); auto out_lens = A->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back(); out_lens.back() = B->get_shape().lens().back();
auto l3 = args[2]; auto C = args[2];
auto l3_lens = l3->get_shape().lens(); auto C_lens = C->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(not std::equal(out_lens.begin(), out_lens.end(), C_lens.begin(), C_lens.end()))
{ {
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), C = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]); args[2]);
} }
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal); auto beta_C = info.add_broadcastable_binary_op("mul", C, beta_literal);
if(beta_l3->get_shape().type() != dot_type) if(beta_C->get_shape().type() != dot_type)
{ {
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), beta_C = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3); beta_C);
} }
return info.add_instruction(make_op("add"), ret, beta_l3); return info.add_instruction(make_op("add"), ret, beta_C);
} }
} }
......
No preview for this file type
...@@ -1988,71 +1988,121 @@ def gathernd_batch_dims_test(): ...@@ -1988,71 +1988,121 @@ def gathernd_batch_dims_test():
@onnx_test @onnx_test
def gemm_test(): def gemm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [11, 5]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, []) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 7])
a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [7, 11]) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm', node = onnx.helper.make_node('Gemm',
inputs=['0', '1', '2'], inputs=['A', 'B', 'C'],
outputs=['3'], outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [A, B, C], [Y])
@onnx_test
def gemm_no_C_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 7])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [7, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=2.0, alpha=2.0,
beta=2.0, beta=2.0,
transA=1, transA=1,
transB=1) transB=1)
return ([node], [x, y, z], [a]) return ([node], [A, B, C], [Y])
@onnx_test @onnx_test
def gemm_ex_test(): def gemm_brcst_C_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 8, 6]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 8, 7]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [5, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7]) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm', node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'], inputs=['A', 'B', 'C'],
outputs=['y'], outputs=['Y'],
alpha=0.5, alpha=0.5,
beta=0.8, beta=0.8,
transA=1) transA=1)
return ([node], [m1, m2, m3], [y]) return ([node], [A, B, C], [Y])
@onnx_test @onnx_test
def gemm_ex_brcst_test(): def gemm_half_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT16, [8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT16, [8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 1]) C = helper.make_tensor_value_info('C', TensorProto.FLOAT16, [6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7]) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT16, [6, 7])
node = onnx.helper.make_node('Gemm', node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'], inputs=['A', 'B', 'C'],
outputs=['y'], outputs=['Y'],
alpha=0.5, alpha=0.5,
beta=0.8, beta=0.8,
transA=1) transA=1)
return ([node], [m1, m2, m3], [y]) return ([node], [A, B, C], [Y])
@onnx_test @onnx_test
def gemm_half_test(): def gemm_dyn_inner_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [None, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1]) Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7])
node = onnx.helper.make_node('Gemm', node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'], inputs=['A', 'B'],
outputs=['y'], outputs=['Y'],
alpha=0.5, alpha=0.5,
beta=0.8,
transA=1) transA=1)
return ([node], [m1, m2, m3], [y]) return ([node], [A, B], [Y])
@onnx_test
def gemm_dyn_outer_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=2.0,
beta=2.0,
transA=1,
transB=1)
return ([node], [A, B, C], [Y])
@onnx_test
def gemm_dyn_C_error():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 7])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=1.0,
beta=1.0,
transA=1)
return ([node], [A, B], [Y])
@onnx_test @onnx_test
......
...@@ -2026,64 +2026,64 @@ TEST_CASE(gemm_test) ...@@ -2026,64 +2026,64 @@ TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {8, 6}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto alpha = 2.f; auto alpha = 0.5f;
auto beta = 2.0f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_test.onnx"); auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_ex_test) TEST_CASE(gemm_no_C_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 0.5f; auto alpha = 2.f;
auto beta = 0.8f; auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_test.onnx"); auto prog = optimize_onnx("gemm_no_C_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_ex_brcst_test) TEST_CASE(gemm_brcst_C_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}}); auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}}); auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}}); auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 1}});
std::vector<std::size_t> out_lens{1, 1, 6, 7}; std::vector<std::size_t> out_lens{6, 7};
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
...@@ -2093,7 +2093,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -2093,7 +2093,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx"); auto prog = optimize_onnx("gemm_brcst_C_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2101,17 +2101,17 @@ TEST_CASE(gemm_half_test) ...@@ -2101,17 +2101,17 @@ TEST_CASE(gemm_half_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}}); auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f; auto alpha = 0.5f;
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction( t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7}; std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction( l2 = mm->add_instruction(
...@@ -2127,6 +2127,64 @@ TEST_CASE(gemm_half_test) ...@@ -2127,6 +2127,64 @@ TEST_CASE(gemm_half_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_dyn_inner_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {7, 7, 0}}});
auto alpha = 0.5f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 8};
auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_outer_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_C_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); }));
}
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment