Commit 87e0f7c2 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_onnx_gemm' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner

parents b5ebcc6b c45ba3d3
......@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
auto a_arg = args[0];
auto b_arg = args[1];
if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
{
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
std::to_string(b_arg->get_shape().ndim()));
}
float alpha = 1.0f;
float beta = 1.0f;
bool trans_a = false;
bool trans_b = false;
if(contains(info.attributes, "alpha"))
{
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
......@@ -53,61 +62,65 @@ struct parse_gemm : op_parser<parse_gemm>
}
if(contains(info.attributes, "transA"))
{
transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
}
if(contains(info.attributes, "transB"))
{
transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0];
auto dot_type = l1->get_shape().type();
std::vector<int64_t> perm = {1, 0};
auto dot_type = a_arg->get_shape().type();
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
if(l1->get_shape().type() != dot_type)
a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
if(a_arg->get_shape().type() != dot_type)
{
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
a_arg =
info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
}
}
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
a_arg = (trans_a)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
: a_arg;
b_arg = (trans_b)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2);
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3)
{
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
// TODO: support dynamic C input
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
{
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes");
}
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
if(beta_l3->get_shape().type() != dot_type)
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_c->get_shape().type() != dot_type)
{
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3);
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_c);
}
return info.add_instruction(make_op("add"), ret, beta_l3);
return info.add_instruction(make_op("add"), ret, beta_c);
}
}
......
......@@ -66,7 +66,7 @@ TEST_CASE(load_and_run_init_list)
TEST_CASE(quantize_fp16)
{
auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
auto p1 = migraphx::parse_onnx("gemm_test.onnx");
const auto& p2 = p1;
const auto& p3 = p1;
migraphx::quantize_fp16(p1);
......@@ -82,7 +82,7 @@ TEST_CASE(quantize_fp16)
TEST_CASE(quantize_int8)
{
auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
auto p1 = migraphx::parse_onnx("gemm_test.onnx");
const auto& p2 = p1;
auto t = migraphx::target("ref");
migraphx::quantize_int8_options options;
......
No preview for this file type
......@@ -2116,71 +2116,136 @@ def gathernd_batch_dims_test():
@onnx_test()
def gemm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [11, 5])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [])
a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [7, 11])
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['0', '1', '2'],
outputs=['3'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_no_C_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 7])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [7, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=2.0,
beta=2.0,
transA=1,
transB=1)
return ([node], [x, y, z], [a])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_ex_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7])
def gemm_brcst_C_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [5, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_ex_brcst_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7])
def gemm_half_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT16, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT16, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT16, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT16, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_half_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7])
def gemm_dyn_inner_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [None, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B'],
outputs=['Y'],
alpha=0.5,
transA=1)
return ([node], [A, B], [Y])
@onnx_test()
def gemm_dyn_outer_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B'],
outputs=['Y'],
alpha=2.0,
transA=1,
transB=1)
return ([node], [A, B], [Y])
@onnx_test()
def gemm_dyn_C_error():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 7])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=1.0,
beta=1.0,
transA=1)
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_rank_error():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [4, 1, 8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [4, 1, 8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4, 1, 6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
@onnx_test()
......
......@@ -2135,64 +2135,64 @@ TEST_CASE(gemm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_test)
TEST_CASE(gemm_no_C_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_test.onnx");
auto prog = optimize_onnx("gemm_no_C_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_brcst_test)
TEST_CASE(gemm_brcst_C_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 1}});
std::vector<std::size_t> out_lens{6, 7};
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
......@@ -2202,7 +2202,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
auto prog = optimize_onnx("gemm_brcst_C_test.onnx");
EXPECT(p == prog);
}
......@@ -2210,17 +2210,17 @@ TEST_CASE(gemm_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}});
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
......@@ -2236,6 +2236,60 @@ TEST_CASE(gemm_half_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_inner_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {7, 7, 0}}});
auto alpha = 0.5f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 8};
auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_outer_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto alpha = 2.f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_C_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); }));
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
}
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
......
......@@ -451,6 +451,94 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(gemm_test)
{
migraphx::program p = migraphx::parse_onnx("gemm_brcst_C_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape a_shape{migraphx::shape::float_type, {5, 6}};
std::vector<float> a_data = {0.26472837, 0.8525864, 0.41929847, 0.14151508, 0.43216065,
0.67468566, 0.42488748, 0.82021785, 0.9782456, 0.5794279,
0.6627283, 0.4790396, 0.9237051, 0.7340607, 0.67379653,
0.87168175, 0.37324256, 0.33278653, 0.42736676, 0.024699844,
0.75851107, 0.48719302, 0.5834426, 0.6938476, 0.43747696,
0.24054702, 0.26912406, 0.6760658, 0.5419149, 0.89949054};
migraphx::shape b_shape{migraphx::shape::float_type, {5, 7}};
std::vector<float> b_data = {
0.65727437, 0.54262096, 0.14126152, 0.8994123, 0.21831702, 0.81191784, 0.9371278,
0.3438551, 0.7121373, 0.90316695, 0.26614252, 0.80144906, 0.80301756, 0.49930334,
0.0719704, 0.63484156, 0.7343097, 0.32130218, 0.7094916, 0.6116475, 0.74144083,
0.021210382, 0.38724765, 0.44830495, 0.62347615, 0.022489505, 0.23316588, 0.76540905,
0.895689, 0.81540287, 0.223875, 0.9275573, 0.4621397, 0.70785195, 0.5658555};
migraphx::shape c_shape{migraphx::shape::float_type, {6, 1}};
std::vector<float> c_data = {
0.07358502, 0.13792239, 0.8574055, 0.40553397, 0.38205826, 0.62062204};
migraphx::parameter_map params;
params["A"] = migraphx::argument(a_shape, a_data.data());
params["B"] = migraphx::argument(b_shape, b_data.data());
params["C"] = migraphx::argument(c_shape, c_data.data());
auto result = p.eval(params).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.45261115, 0.83629227, 0.7533463, 0.7189715, 0.69160205, 0.824082, 0.9187499,
0.6659525, 0.96956736, 0.84293026, 0.8400868, 0.84835225, 1.0982862, 1.0642393,
1.1447254, 1.6184721, 1.6048342, 1.4741788, 1.4334437, 1.638659, 1.7428316,
0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702,
0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703,
1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(gemm_half_test)
{
migraphx::program p = migraphx::parse_onnx("gemm_half_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape a_shape{migraphx::shape::half_type, {8, 6}};
std::vector tmp = {0.2646, 0.8525, 0.4192, 0.1415, 0.4321, 0.675, 0.4248, 0.8203,
0.978, 0.5796, 0.6626, 0.479, 0.924, 0.734, 0.674, 0.8716,
0.3733, 0.3328, 0.4272, 0.0247, 0.7583, 0.4873, 0.5835, 0.694,
0.4375, 0.2406, 0.269, 0.6763, 0.542, 0.8994, 0.657, 0.5425,
0.1412, 0.8994, 0.2183, 0.812, 0.937, 0.3438, 0.712, 0.9033,
0.266, 0.8013, 0.803, 0.4993, 0.07196, 0.635, 0.7344, 0.3213};
std::vector<migraphx::half> a_data{tmp.cbegin(), tmp.cend()};
migraphx::shape b_shape{migraphx::shape::half_type, {8, 7}};
tmp = {0.7095, 0.612, 0.741, 0.02121, 0.3872, 0.4482, 0.6235, 0.02249, 0.2332, 0.7656,
0.8955, 0.8154, 0.2239, 0.9277, 0.4622, 0.708, 0.566, 0.0736, 0.138, 0.8574,
0.4055, 0.382, 0.6206, 0.424, 0.3674, 0.435, 0.998, 0.3594, 0.701, 0.6216,
0.01826, 0.6313, 0.514, 0.1095, 0.3203, 0.01636, 0.537, 0.01952, 0.4502, 0.8965,
0.5415, 0.7456, 0.793, 0.756, 0.9, 0.5264, 0.05368, 0.4214, 0.276, 0.1517,
0.08453, 0.83, 0.417, 0.1682, 0.845, 0.1729};
std::vector<migraphx::half> b_data{tmp.cbegin(), tmp.cend()};
migraphx::shape c_shape{migraphx::shape::half_type, {6, 1}};
tmp = {0.10846, 0.672, 0.527, 0.94, 0.429, 0.2291};
std::vector<migraphx::half> c_data{tmp.cbegin(), tmp.cend()};
migraphx::parameter_map params;
params["A"] = migraphx::argument(a_shape, a_data.data());
params["B"] = migraphx::argument(b_shape, b_data.data());
params["C"] = migraphx::argument(c_shape, c_data.data());
auto result = p.eval(params).back();
std::vector<migraphx::half> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
tmp = {1.071, 1.378, 1.465, 1.093, 0.968, 1.542, 1.145, 1.287, 1.533, 1.75, 1.338,
1.449, 1.592, 1.668, 1.265, 1.531, 1.656, 1.348, 1.2705, 1.525, 1.479, 1.754,
2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155,
1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(greaterorequal_test)
{
migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx");
......
......@@ -1027,7 +1027,7 @@ TEST_CASE(contiguous_dyn_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(conv_dynamic_batch_test)
TEST_CASE(conv_dyn_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment