Commit d5122475 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more test cases for the dot operator.

parent 626f8dc1
......@@ -845,6 +845,11 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if (!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
}
// only handle the case that the batch size of a and b are the same
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
......
......@@ -552,7 +552,9 @@ struct onnx_parser
auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
......
......@@ -170,11 +170,7 @@ rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_hal
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
if(input_shapes.size() == 3)
{
auto c_shape = inputs[2];
check_shapes{{c_shape}}.not_broadcasted();
}
check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
}
......@@ -235,8 +231,7 @@ argument miopen_gemm::compute(context& ctx,
m * n,
num_matrices);
});
// device::add(ctx.get_stream().get(), args[3], args[2], args[3]);
return args[3];
}
......@@ -244,117 +239,48 @@ argument miopen_gemm::compute(context& ctx,
// vector inner product
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
if(output_shape.elements() == 1)
{
assert(args[0].get_shape().elements() == args[1].get_shape().elements());
float beta = 0.0f;
rocblas_int elem_num = static_cast<rocblas_int>(args[0].get_shape().elements());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
// the function generic_rocblas_dot is not stable, so have to
// call the gemm function instead. In the future, we may change
// to call generic_rocblas_dot when it is stable.
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
rocblas_operation_none,
rocblas_operation_none,
1,
1,
elem_num,
&alpha_r,
to_pointer(args[1]),
1,
to_pointer(args[0]),
elem_num,
&beta_r,
to_pointer(args[2]),
1);
});
}
// matrix * vector (b is a vector)
else if(b_lens.size() == 2 && b_lens.at(1) == 1)
{
bool transa = args[0].get_shape().transposed();
rocblas_int m = static_cast<rocblas_int>(a_lens[0]);
rocblas_int n = static_cast<rocblas_int>(a_lens[1]);
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
float beta = 0.0f;
assert(a_lens.back() == args[1].get_shape().elements());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(),
transa ? rocblas_operation_transpose : rocblas_operation_none,
m,
n,
&alpha_r,
to_pointer(args[0]),
lda,
to_pointer(args[1]),
1,
&beta_r,
to_pointer(args[2]),
1);
});
}
// vector * matrix (a is a vector)
else if(a_lens.size() == 2 && a_lens.at(0) == 1)
{
bool transb = !args[1].get_shape().transposed();
rocblas_int ldb = args[1].get_shape().strides()[(transb ? 1 : 0)];
rocblas_int m = b_lens[0];
rocblas_int n = b_lens[1];
float beta = 0.0f;
assert(b_lens[0] == args[0].get_shape().elements());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemv(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
m,
n,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
1,
&beta_r,
to_pointer(args[2]),
1);
});
}
// batch matrix multiplication
else
{
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
float beta = 0.0f;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
float beta = 0.0f;
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (num_matrices == 1)
{
generic_rocblas_gemm(
as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
to_pointer(args[2]),
ldc);
}
else
{
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
......@@ -375,8 +301,8 @@ argument miopen_gemm::compute(context& ctx,
ldc,
m * n,
num_matrices);
});
}
}
});
return args[2];
}
......
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
template <class T>
void matmul_test()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test<float>)
TEST_CASE_REGISTER(matmul_test<double>)
template <class T>
void matmul_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test_ex<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>)
TEST_CASE(matmul_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_beta0)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
std::vector<float> m3 = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 2, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_beta_0)
{
migraphx::program p;
std::vector<float> m1 = {
-0.76234141, 0.01368910, -0.86343423, -0.99465282, 0.76133268, 0.96507140};
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
std::vector<float> m2 = {-0.15933632,
-0.69594712,
-0.06198966,
-1.23905184,
-0.83672704,
-1.06971832,
-0.12272917,
1.07094116,
-0.08346820,
1.16820693,
-0.95700874,
0.24059691};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
std::vector<float> m3 = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(matmul_mutli_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_3args)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_3args)
{
{
migraphx::program p;
std::vector<float> a = {-0.86217194,
-1.04129542,
-0.64850364,
-0.97078327,
-0.40516386,
0.83136927,
0.37717502,
0.42271939,
1.10062165,
-0.92239359,
0.40403076,
-0.43935377};
std::vector<float> b = {0.76084386,
1.89201125,
1.73218067,
0.7148568,
-0.55578914,
0.05799101,
-1.24090721,
-0.51151978,
1.13255803,
0.21540723,
-1.10459009,
0.45580331};
std::vector<float> c = {-0.80473623,
0.35154171,
-2.73077756,
-0.09093885,
-1.88850472,
-0.03375556,
-0.41798276,
2.87368099,
2.11031439};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = p.add_literal(migraphx::literal{c_shape, c});
p.add_instruction(migraphx::op::dot{}, al, bl, cl);
std::vector<float> gold = {-1.60947,
0.703083,
-5.46156,
-0.181878,
-3.77701,
-0.0675112,
-0.835966,
5.74736,
4.22063};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_vv_inner_product)
{
{
migraphx::program p;
std::vector<float> a = {0.7481789,
0.02906279,
1.01193836,
1.60222907,
1.89135978,
0.30054158,
-0.4892588,
-0.27027533};
std::vector<float> b = {-0.25829116,
0.27908929,
-1.27888957,
0.21152361,
0.08593658,
0.52163899,
1.38343824,
-0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, ual, ubl);
std::vector<float> gold = {-1.43461};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {0.7481789,
0.02906279,
1.01193836,
1.60222907,
1.89135978,
0.30054158,
-0.4892588,
-0.27027533};
std::vector<float> b = {-0.25829116,
0.27908929,
-1.27888957,
0.21152361,
0.08593658,
0.52163899,
1.38343824,
-0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.32f;
p.add_instruction(migraphx::op::dot{alpha}, ual, ubl);
std::vector<float> gold = {-0.4590752};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_vm)
{
{
migraphx::program p;
std::vector<float> a = {1.49530002,
-0.07181969,
0.44593846,
-0.8645019,
0.52992304,
-0.4910338,
-2.12179422,
-0.45962977};
std::vector<float> b = {-0.06210242, 0.0187149, 1.47482984, -1.19590602, -0.45601701,
0.36934488, -0.83913193, 0.75350964, 0.80707019, 0.35923582,
-2.18480722, -0.85608682, 0.75849199, 0.49103473, -0.91329477,
-0.36364322, -0.69688937, 0.07165814, -0.15505523, 0.52221663,
-0.98631192, -0.37353654, -1.89818706, -0.87209739, -0.33942003,
0.11390353, 0.78181162, -0.18395337, -0.34743419, -0.08091231,
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, ual, bl);
std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {1.49530002,
-0.07181969,
0.44593846,
-0.8645019,
0.52992304,
-0.4910338,
-2.12179422,
-0.45962977};
std::vector<float> b = {-0.06210242, 0.0187149, 1.47482984, -1.19590602, -0.45601701,
0.36934488, -0.83913193, 0.75350964, 0.80707019, 0.35923582,
-2.18480722, -0.85608682, 0.75849199, 0.49103473, -0.91329477,
-0.36364322, -0.69688937, 0.07165814, -0.15505523, 0.52221663,
-0.98631192, -0.37353654, -1.89818706, -0.87209739, -0.33942003,
0.11390353, 0.78181162, -0.18395337, -0.34743419, -0.08091231,
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f;
p.add_instruction(migraphx::op::dot{alpha}, ual, bl);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535};
std::vector<float> b = {
1.2459538, 0.39586199, -0.77035574, 0.22689828, 0.3289835, 1.02804361,
-0.22941113, -0.33940324, 0.80078249, 1.0319152, 0.80034948, -0.11631159,
0.36899208, -0.28506697, -1.2211584, -0.55678377, -0.3618498, 0.34857264,
-0.38700147, -0.43434611, 1.73029783, -0.71578372, 0.09777723, 0.06616614,
-1.66721186, -0.16046032, -1.64581663, 1.09373609, -0.14127692, -0.01938473,
-0.67310303, -1.56154787, -1.0665462, 0.68538535, -1.53920085, -0.35710272,
0.06887234, 0.17474616, 1.08194804, -0.19990148, -0.91149488, 0.95303646,
0.95448717, -0.49332393, -1.762213, -0.56571194, -1.69704968, -0.82798066,
0.65531872, 1.5007798, 0.99877355, 0.53386114, -0.88150609, -1.0756985,
0.50962511, -0.68019002, 0.1583068, 2.83988407, -1.10292457, 0.02126969,
0.21129951, 0.25690146, -1.6490316, 0.55261771, -1.70504303, -0.02870394,
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, bual, bl);
std::vector<float> gold = {1.22914,
-1.17896,
2.28596,
-0.345637,
-0.962362,
0.168508,
-0.947471,
-3.02458,
-3.80131,
1.38484,
-2.45019,
-1.35064};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-1.7468318, -0.38900251, 1.00183915, 0.06016438, 0.08295905, 1.5830535};
std::vector<float> b = {
1.2459538, 0.39586199, -0.77035574, 0.22689828, 0.3289835, 1.02804361,
-0.22941113, -0.33940324, 0.80078249, 1.0319152, 0.80034948, -0.11631159,
0.36899208, -0.28506697, -1.2211584, -0.55678377, -0.3618498, 0.34857264,
-0.38700147, -0.43434611, 1.73029783, -0.71578372, 0.09777723, 0.06616614,
-1.66721186, -0.16046032, -1.64581663, 1.09373609, -0.14127692, -0.01938473,
-0.67310303, -1.56154787, -1.0665462, 0.68538535, -1.53920085, -0.35710272,
0.06887234, 0.17474616, 1.08194804, -0.19990148, -0.91149488, 0.95303646,
0.95448717, -0.49332393, -1.762213, -0.56571194, -1.69704968, -0.82798066,
0.65531872, 1.5007798, 0.99877355, 0.53386114, -0.88150609, -1.0756985,
0.50962511, -0.68019002, 0.1583068, 2.83988407, -1.10292457, 0.02126969,
0.21129951, 0.25690146, -1.6490316, 0.55261771, -1.70504303, -0.02870394,
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{0.21f}, bual, bl);
std::vector<float> gold = {0.25812,
-0.247582,
0.480051,
-0.0725837,
-0.202096,
0.0353867,
-0.198969,
-0.635161,
-0.798275,
0.290817,
-0.514539,
-0.283635};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mv)
{
{
migraphx::program p;
std::vector<float> a = {0.1612524,
0.61266466,
-0.19212896,
1.34228825,
-1.09746949,
0.4680955,
-0.431748,
-0.89791241,
-2.19078702,
-0.13767058,
-1.66105228,
-0.91834613,
0.59199744,
1.41967261,
0.76237423};
std::vector<float> b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, al, ubl);
std::vector<float> gold = {1.31982, 1.19022, -1.96062};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {0.1612524,
0.61266466,
-0.19212896,
1.34228825,
-1.09746949,
0.4680955,
-0.431748,
-0.89791241,
-2.19078702,
-0.13767058,
-1.66105228,
-0.91834613,
0.59199744,
1.41967261,
0.76237423};
std::vector<float> b = {0.14365572, 0.23401411, -0.8970094, -0.12526676, -1.04703286};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.3f;
p.add_instruction(migraphx::op::dot{alpha}, al, ubl);
std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
1.24593227, -0.84351316, 0.27882229, -0.42518484, -1.11391528, 0.59141834,
1.34198714, 2.25884063, -1.32093452, 0.44766336, -0.09306479, 0.47526699,
0.25858488, 1.30820392, 1.17186787, 0.31530864, -1.19159424, -0.24100903,
-1.03857886, 1.54453427, 0.05041654, 1.67108177, 0.965805, 0.52958924,
-1.61243992, 0.02941846, 0.77523836, 1.97963853, -2.51093596, 0.21882645,
-2.60193574, 1.1899952, 1.70883519, 0.94586745, 2.65002512, -1.42427102,
1.0143951, -1.34115312, 1.63833732, -1.46477355, 0.44014877, 0.58032696,
-1.63874372, -0.82834423, 1.81131778, -0.52393379, 1.16721943, 0.39488835,
0.23947128, -0.15733194, 0.19451158, 1.21315445, 0.44594897, 0.40809135,
-0.64252994, 0.7541716, -0.97203195, 0.69208485, 0.34350988, 0.9836842};
std::vector<float> b = {0.05013914, 1.39932885, 2.56616476, 1.02225623, -0.03977829};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl);
p.add_instruction(migraphx::op::dot{}, al, bubl);
std::vector<float> gold = {-0.792717,
6.33595,
2.61466,
-3.39322,
5.42485,
3.59084,
6.78139,
-0.360492,
-4.28998,
2.87146,
3.29447,
0.765651};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mm1)
{
{
migraphx::program p;
std::vector<float> a = {
-0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437,
-0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043,
-0.25798175, -0.2504807, -1.11134383, -0.71030613, -0.20234025, 0.90229168,
0.62643053, -0.83512638, 1.66051254, 0.05941673, 0.73081559, 0.27111867,
0.55060745, 0.34999583, 1.02236619, 0.60178395, 1.49646162, 1.93255155,
-3.65357913, -1.38059906, -0.46302398, 0.19847152, 0.39785875, 1.47004861,
-1.24482133, -0.01954702, 0.36073898, 1.56055978, -0.10344603, -0.34283135,
-0.56482649, 1.80861249, -0.92268202, 0.94371182, -0.02373232, -0.75441145,
0.43325034, 0.4057425, -0.48844822, -0.36390512, 0.74110406, 1.25158366,
0.52196654, 1.43461691, -0.57530864, -0.66716206, -1.76516289, 0.96582849};
std::vector<float> b = {0.49899375,
-2.20168661,
1.08895066,
-0.01135643,
0.90570669,
-1.43550963,
-1.73033377,
0.21338776,
0.96962508,
0.38913968,
-0.32822861,
0.88222863,
0.93330718,
-1.24265228,
-1.62587164};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
0.182745, -4.21937, 1.77551, 1.50775, -2.60888, -2.32484,
-0.557691, 6.13527, -2.91743, 2.37836, -6.42584, 1.14979,
0.77227, 0.349659, 2.92759, 2.32384, -2.90664, 0.0527679,
-0.547761, -0.155467, 0.964619, 2.09133, -4.44281, -1.3864};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.0309568,
-1.57294749,
-0.00768606,
1.5786921,
0.50519718,
0.10530702,
-0.05302112,
-0.06503757,
0.4079716,
0.0799132,
-0.82624962,
0.49341502};
std::vector<float> b = {
0.3664867, 0.24649534, 1.14728076, 1.09911548, -1.23711247, -0.49436419,
-0.67557879, -0.84180575, -1.09754376, 0.07807351, 0.74349043, -0.92084701,
0.50267885, 0.78709401, 0.80598159, -0.51269589, -0.40337193, 0.29457878,
1.25447301, -1.66251457, -1.54652239, -0.35067765, -0.5214464, -0.7866878,
1.11128573, 0.26927291, -0.0929818, 0.07523954, 0.3256776, -1.08617826,
0.89294253, -0.91007619, -2.42825765, -1.76805581, 1.08136334, -0.14521253,
-1.32061148, 0.60663124, -1.19835255, -0.98803563, -1.06927896, -0.51967419,
-0.98974639, 1.01287011, 1.34910394, 0.1203349, 0.67387452, -0.32447465,
1.15187449, -0.82253807, 0.22302433, 0.46434695, 0.319647, 1.56459445,
0.15664012, 0.03998102, 0.62981041, 0.11831296, 0.47824434, -0.93941882,
-0.34674036, 1.17071104, 0.59203806, 2.75817738, -0.69300013, 1.30971899,
-0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, bal, bl);
std::vector<float> gold = {
-1.61175, 3.11849, -0.703205, 0.331635, -0.00946922, 0.645626, 0.834069, 1.06409,
0.881037, 0.227628, -0.200308, -1.71836, 0.156255, 0.477222, 0.571363, -1.04543,
1.40524, 1.24201, -2.95083, 1.19352, 1.5008, 0.636987, 0.148256, -0.0231631,
-1.15079, 1.42139, 1.80996, 1.79259, 2.7192, 0.331902, -0.726565, 0.0963351,
-0.710558, 0.259424, -0.342345, -1.80522, -0.580476, 0.277368, -3.95582, 0.614823,
-0.415107, 0.305138, 0.435993, -0.107089, -0.767885, -4.00837, 1.09921, -2.02129,
0.109717, 0.618422, 0.438342, 0.29602, 2.00928, 0.420871};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
TEST_CASE(matmul_mm2)
{
{
migraphx::program p;
std::vector<float> a = {
-0.49450006, -1.07431991, -0.02796692, -0.99631927, 0.20040449, -1.39709437,
-0.15695328, 0.08208373, -0.09746386, 0.77923021, -0.1849151, 0.14419043,
-0.25798175, -0.2504807, -1.11134383, -0.71030613, -0.20234025, 0.90229168,
0.62643053, -0.83512638, 1.66051254, 0.05941673, 0.73081559, 0.27111867,
0.55060745, 0.34999583, 1.02236619, 0.60178395, 1.49646162, 1.93255155,
-3.65357913, -1.38059906, -0.46302398, 0.19847152, 0.39785875, 1.47004861,
-1.24482133, -0.01954702, 0.36073898, 1.56055978, -0.10344603, -0.34283135,
-0.56482649, 1.80861249, -0.92268202, 0.94371182, -0.02373232, -0.75441145,
0.43325034, 0.4057425, -0.48844822, -0.36390512, 0.74110406, 1.25158366,
0.52196654, 1.43461691, -0.57530864, -0.66716206, -1.76516289, 0.96582849};
std::vector<float> b = {-1.12211357, 1.74720423, 0.60382572, -0.61090125, -0.3315936,
0.30924675, -0.28906435, 0.64039247, -1.2822253, 0.55899286,
2.14013013, 1.00944809, 0.21660017, -0.75465098, 0.12097934,
-1.64006315, 0.43582108, -0.64348541, 0.43101069, 1.30191386,
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
-1.84001907, 3.51427391, 0.42425673, 0.0638482, 2.40210271, 1.50027643,
4.81988916, -3.63687142, -0.19101717, -4.92522092, -1.76377022, -3.58095615,
1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025,
0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821};
p.add_instruction(migraphx::op::dot{}, al, bbl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.19276159, -1.2568421, -0.321242, 1.21471077, -0.4927751,
0.69446894, -0.1786371, -1.00763473, -0.10279314, 3.02931355,
1.08359235, -0.35190132, -0.00639111, 0.78989113, 1.23538029,
0.4590747, 0.17304142, 0.42512412, 0.21076913, -0.01724556,
-0.17763898, 0.12852236, -0.00459301, 1.34498824, 0.02907823,
0.1784464, -0.20790355, -0.52336699, 0.45804085, 1.06025801};
std::vector<float> b = {-1.12211357, 1.74720423, 0.60382572, -0.61090125, -0.3315936,
0.30924675, -0.28906435, 0.64039247, -1.2822253, 0.55899286,
2.14013013, 1.00944809, 0.21660017, -0.75465098, 0.12097934,
-1.64006315, 0.43582108, -0.64348541, 0.43101069, 1.30191386,
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, bal, bbl);
std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
1.91883003e+00, -2.89962196e-01, 2.76404142e+00, 1.50048102e+00, -6.29650347e-01,
1.48105185e+00, -3.71716505e-03, 8.80281500e-01, 2.50057585e+00, 1.29958508e+00,
5.63751779e-01, 2.25703781e-01, 1.30516919e+00, 8.32118386e-01, 2.44050864e-01,
-2.49748221e+00, -5.60803176e+00, -2.98919069e+00, -1.11429417e+00, -3.29675989e+00,
1.02442564e-01, -1.87659303e+00, -4.67302454e-01, 9.16189968e-01, -1.33537175e-01,
8.27398578e-01, 1.94406914e+00, -2.39250915e-01, -1.77062701e+00, -6.46239534e-01,
-7.95202750e-01};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149,
-0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771,
1.28616952, 1.55384379, 0.30376261, -1.12356544, -0.64271552, -2.50703079,
-0.23994372, 0.8166084, 0.06542249, -0.17472336, -0.37665211, 0.16342699,
0.07645941, 0.65024333, -1.19883423, -0.40536776, -0.31132765, 0.78113691,
-0.16887638, 2.30797418, -0.36241233, 0.33552153, -1.05343996, -0.16909699,
-1.22608815, 1.64165613, 0.96260828, -0.16733976, 0.84211199, 1.31243813,
0.89258549, -0.48250384, -1.06005206, 1.37021342, -0.35658565, 0.26879188};
std::vector<float> b = {
0.17111129, -0.82134741, -1.58001178, -1.46759447, 0.31522514, -0.11567352,
-0.038978, -0.3601414, -0.84379876, 0.24848939, -0.37080544, 0.00838631,
1.51316241, 0.42385344, 2.06043846, 1.82348849, 1.07180434, 0.6567393,
1.41164561, 0.73091185, -0.33541302, -0.98082287, -0.06605479, 0.82219717,
-1.41619634, 0.51326658, 0.26916313, 0.79819769, 0.85583702, 0.07876046,
-0.42375545, -0.7758751, 1.14334296, -0.14211708, -1.54520411, -0.55244869,
-0.48478899, 0.10782164, -0.20879552, -0.99019754, 1.78783102, -1.31610052,
1.73510175, -0.48360172, 0.62367417, -1.34180545, -0.37512931, -1.50521357,
0.08383314, 0.76165608, -0.4961646, 0.95821311, -0.68407191, 0.48299435,
-0.24323988, 0.34793412, 0.37908669, 1.19083454, 1.30218795, -0.26731035,
-0.34544132, -0.09595373, 0.50951334, 0.48896956, 0.38753818, -0.4939919,
0.02352126, 0.42013764, 0.07027765, 0.21169851, -0.24411376, -1.77793736,
-0.88370924, 0.95294025, -0.08208804, -0.95943892, 0.30280474, 1.1967013,
-1.17700948, 0.29533973};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
std::vector<float> gold = {
1.22136035, 1.3765651, 2.0611395, 1.70445494, 1.8189619, 0.2509717,
0.88815736, 1.13837946, 1.37006127, -0.53617378, 0.45759693, -0.503786,
-0.10575749, -0.81715738, 2.56316255, 0.85812927, -0.53425671, 1.38147704,
2.57874755, -1.05591061, -1.42065674, -0.25412658, -2.14494165, -2.81045272,
0.27491485, -0.04229986, 0.10181043, -0.55680682, -0.07633866, 0.313767,
-0.28202571, -1.64696179, -0.50872733, -1.08935912, 0.94291084, -0.71792156,
0.82981387, 1.14797592, 3.13989358, -0.17507726, -0.63429162, -0.72241531,
-0.61459168, -0.52561056, 0.3309648, -0.46185697, -1.60586695, -0.98590829,
0.63012062, -0.25606052, -0.69419352, -1.78299913, -0.38572706, 1.92249442,
0.3884186, -0.48153048, 0.84932351, 0.67234919, -1.07821322, -0.01208216};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {
-0.55248691, 0.70275958, 0.56967633, 0.88206033, -0.85088547, 0.05689149,
-0.20084703, 0.18024434, 1.0730491, 0.15913531, 0.93621628, 0.35072771,
1.28616952, 1.55384379, 0.30376261, -1.12356544, -0.64271552, -2.50703079,
-0.23994372, 0.8166084, 0.06542249, -0.17472336, -0.37665211, 0.16342699,
0.07645941, 0.65024333, -1.19883423, -0.40536776, -0.31132765, 0.78113691,
-0.16887638, 2.30797418, -0.36241233, 0.33552153, -1.05343996, -0.16909699,
-1.22608815, 1.64165613, 0.96260828, -0.16733976, 0.84211199, 1.31243813,
0.89258549, -0.48250384, -1.06005206, 1.37021342, -0.35658565, 0.26879188};
std::vector<float> b = {-0.33734601, 0.66386073, 0.41425048, 0.40190389, -0.99645073,
-0.10017067, -0.58542118, 0.48636962, 0.06301405, 1.14669128,
-0.06526677, 0.23172741, -1.49693143, -0.44464233, -0.12775566,
-1.32038007, 1.1812471, 1.22362746, -0.49013843, 0.25339836,
1.31698705, 1.54256669, 0.11211132, -0.18005487, 0.36730145,
0.97705953, -0.18909084, 0.544932, 0.32891878, 0.64250015,
-0.41381398, 0.47402562, 1.22286761, 1.07573211, -0.92988077,
-0.36340925, -1.76152377, -0.96642674, -0.79231929, 0.11517073};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
-0.43180359, 0.19639674, -0.33742881, 0.98443538, -0.9021272, 1.25043704,
-0.45038184, -0.14689614, -0.91749459, 3.49467934, 3.81336312, 2.4482385,
1.49649707, 1.05889193, -3.49343731, -2.06958956, -2.52082858, -1.61401519,
-1.52966956, 0.01191848, -0.33246613, -0.70641362, -0.60391255, 0.28083355,
0.52255496, -1.08655006, 1.64648546, 0.80344255, 0.71987865, -3.00960296,
2.02318221, 3.32785057, -1.13203844, 1.81235734, 0.38067585, -0.88086897,
1.38307367, 0.42677257, 0.83759966, -0.34827442, -1.45067092, 2.09599671,
1.92882983, -0.30996324, 2.19736278, 2.32389426, 2.36741832, 1.62253915,
0.26698225, -0.00741609, -2.53680983, -0.0679954, 0.04499683, 0.85354276};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -876,242 +876,6 @@ TEST_CASE(reshape_test)
}
}
template <class T>
void gemm_test()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
template <class T>
void gemm_test_ex()
{
migraphx::program p;
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<T> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(gemm_test_ex<float>)
TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(gemm_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
-1.93300070, 0.33902698, -0.45173527, -0.72283069, -0.17177134, 1.62199882,
0.87052847, 0.14989811, -0.88969184, -0.18131398, 0.72654339, -0.57123693,
0.03852506, -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
0.08698452, 0.94974586, 0.40087323, -0.05448534, 0.03220677, -1.22494296,
0.97938472, -1.43714454, -0.80430904, -0.08098728, 0.31520301, 0.49642169,
-1.63471091, 0.34390096, 2.81292176, -0.22666528, 1.54559556, -1.51075762};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.33170529, 2.26325120, -0.50639461, 0.64802947, 0.44748888, 0.33768068,
-0.53621075, 0.34341460, 0.58742520, -1.13995790, -0.99322535, 0.35447353,
0.01977110, -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
-0.39112198, 1.42303608, -0.02853060, 1.52610164, 0.53540909, 0.75618998,
-0.26877787, -1.90886366, 0.30622790, 0.59794535, 1.29795331, -0.37805803,
-1.58167176, -1.26966832, 0.27435891, 0.89430347, 0.22854926, -0.50317658};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
p.add_instruction(migraphx::op::dot{}, l1, l2);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
{
migraphx::program p;
std::vector<float> m1 = {
1.23636469, -0.47041261, -0.14375651, -0.48371852, 1.16479301, -0.89361055,
-0.18569086, 1.10700457, -1.02632638, 0.82277012, 0.33525769, 0.52825145,
-1.00141689, 0.45510090, -0.02675039, -0.60454439, 0.38551153, -0.01658514,
0.93059292, -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
-1.36200452, 2.25822236, -1.23416970, 0.12312496, 0.76232760, -0.83594234,
1.67418145, -0.19412936, 1.05261378, 0.66246074, -1.15233398, 0.16429736};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
std::vector<float> m2 = {
-0.87300530, -0.07112838, 0.19196860, -1.04986840, 1.20348200, 0.31966893,
1.04805440, -2.04777729, -0.67906052, -1.17250760, 0.34305044, -1.01957785,
-1.12694862, 0.18431338, -1.63712290, 0.27566931, -1.11282021, 1.41738919,
0.47871283, -1.01980420, 1.00212436, -0.78740444, -1.65636133, 1.51466547,
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2);
auto l_beta = p.add_literal(beta);
auto b_beta = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
auto m3_beta = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(maxpool_test)
{
migraphx::program p;
......
......@@ -870,7 +870,7 @@ struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
}
};
struct gemm_mutli_dim_2
struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
{
migraphx::program create_program() const
{
......@@ -886,7 +886,127 @@ struct gemm_mutli_dim_2
}
};
struct gemm_mutli_dim_2_3
struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
};
struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
};
struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, bl1, bl2);
return p;
}
};
struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
};
struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{
migraphx::program create_program() const
{
......@@ -902,6 +1022,182 @@ struct gemm_mutli_dim_2_3
}
};
struct gemm_2args_vv : verify_program<gemm_2args_vv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
float alpha = 0.23f;
auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sres);
return p;
}
};
struct gemm_2args_mv : verify_program<gemm_2args_mv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, ul2);
return p;
}
};
struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
p.add_instruction(migraphx::op::dot{}, l1, bul2);
return p;
}
};
struct gemm_2args_vm : verify_program<gemm_2args_vm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
return p;
}
};
struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l2 = p.add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
return p;
}
};
struct gemm_multi_3args : verify_program<gemm_multi_3args>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 1.0f;
float beta = 0.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.0f;
float beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
......
matmul-example:{

1
2y"MatMul test_matmulZ
1



Z
2





b
y





B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1



Z
2

b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1


Z
2

b
y

B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1

Z
2



b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1

Z
2


b
y

B
\ No newline at end of file
matmul-example:S

1
2y"MatMul test_matmulZ
1

Z
2

b
y

B
\ No newline at end of file
......@@ -588,6 +588,113 @@ TEST_CASE(gemm_ex)
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_brcst)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 = p.add_instruction(migraphx::op::multibroadcast{out_lens}, l2);
auto alpha = 0.5f;
auto beta = 0.8f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, t2);
auto prog = migraphx::parse_onnx("gemm_test_ex1.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1);
auto sr0 = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sr0);
auto prog = migraphx::parse_onnx("matmul_vv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto prog = migraphx::parse_onnx("matmul_vm.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_vbm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
std::cout << "ONNX_TEST" << std::endl;
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
std::cout << "After Dot" << std::endl;
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_vbm.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_mv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res);
auto prog = migraphx::parse_onnx("matmul_mv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_bmv)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = migraphx::parse_onnx("matmul_bmv.onnx");
EXPECT(p == prog);
}
TEST_CASE(matmul_bmbm)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
auto prog = migraphx::parse_onnx("matmul_bmbm.onnx");
EXPECT(p == prog);
}
TEST_CASE(add_scalar_test)
{
migraphx::program p;
......
......@@ -371,21 +371,67 @@ TEST_CASE(logsoftmax)
}
}
TEST_CASE(dot)
// 2 inputs arguments
TEST_CASE(matmul)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
......@@ -403,45 +449,104 @@ TEST_CASE(dot)
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 7}},
migraphx::op::dot{},
s_m1,
s_m2);
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 7}},
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
}
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::op::dot{},
s_m1,
s_m2);
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {3, 1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::op::dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {2, 2, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment