Unverified Commit d411aa69 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic ref dot operator (#1457)

Extends dot MIGX operator to handle dynamic input shapes
Only allow dot between two dynamic shapes that have exactly matching outer dimensions
Inner dimensions must also match correspondingly
Updates dot related tests
Change check_shapes to use shape.ndim()
ONNX parsers for GEMM and MatMult will be updated in a separate PR
parent 8e7d2efe
...@@ -198,7 +198,7 @@ struct check_shapes ...@@ -198,7 +198,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(not this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gemm.hpp> #include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,41 +39,69 @@ struct dot ...@@ -38,41 +39,69 @@ struct dot
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_type().has(2); check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(not std::all_of( if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions ");
} }
if(a.dynamic() or b.dynamic())
// only handle the case that the batch size of a and b are the same {
auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend()))
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims};
}
else
{
// only handle the case that all the dimensions except the last two are the same
if(not std::equal( if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend())) a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.ndim() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.ndim() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens}; return {t, out_lens};
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result = argument{output_shape}; argument result = argument{dyn_out.computed_shape};
visit_all(result, args[0], args[1])( visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); }); [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result; return result;
......
...@@ -383,9 +383,9 @@ struct ref_gemm ...@@ -383,9 +383,9 @@ struct ref_gemm
std::string name() const { return "ref::dot"; } std::string name() const { return "ref::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
migemm(result, args[0], args[1], 1.0f, 0.0f); migemm(result, args[0], args[1], 1.0f, 0.0f);
return result; return result;
......
...@@ -467,6 +467,199 @@ TEST_CASE(deconvolution_shape) ...@@ -467,6 +467,199 @@ TEST_CASE(deconvolution_shape)
weights_3d); weights_3d);
} }
TEST_CASE(dot_ndim_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error4)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_outer_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_3D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test_1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_4D_test)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {8, 8, 0}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_mismatch_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5, 0}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_dyn_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {4, 5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, 5}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_mismatch_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_mismatch_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4, 0}, {5, 5, 0}, {2, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
...@@ -638,46 +831,6 @@ TEST_CASE(gather) ...@@ -638,46 +831,6 @@ TEST_CASE(gather)
} }
} }
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(get_tuple_elem_test) TEST_CASE(get_tuple_elem_test)
{ {
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}}; migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
...@@ -1131,106 +1284,6 @@ TEST_CASE(lstm) ...@@ -1131,106 +1284,6 @@ TEST_CASE(lstm)
} }
} }
// 2 inputs arguments
TEST_CASE(matmul)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(multibroadcast) TEST_CASE(multibroadcast)
{ {
{ {
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
template <class T> template <class T>
void matmul_test() void dot_2d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -82,11 +82,11 @@ void matmul_test() ...@@ -82,11 +82,11 @@ void matmul_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(matmul_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
template <class T> template <class T>
void matmul_test_ex() void dot_4d_test()
{ {
migraphx::program p; migraphx::program p;
...@@ -133,10 +133,10 @@ void matmul_test_ex() ...@@ -133,10 +133,10 @@ void matmul_test_ex()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(matmul_test_ex<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
TEST_CASE(matmul_mutli_dim_2) TEST_CASE(dot_3D_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2) ...@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim_2_beta0) TEST_CASE(dot_3D_C_test0)
{ {
migraphx::program p; migraphx::program p;
...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_beta_0) TEST_CASE(dot_3D_C_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0) ...@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(matmul_mutli_dim_2_3) TEST_CASE(dot_4D_test1)
{ {
migraphx::program p; migraphx::program p;
...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3) ...@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_dim1_2_3) TEST_CASE(dot_4D_alpha_beta_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_mutli_3args) TEST_CASE(dot_4D_alpha_beta_C_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args) ...@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify_range(m, m_res));
} }
TEST_CASE(gemm_3args) TEST_CASE(dot_2D_C_test0)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args) ...@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
} }
} }
TEST_CASE(matmul_vv_inner_product) TEST_CASE(dot_vv_inner_product)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
} }
} }
TEST_CASE(matmul_vm) TEST_CASE(dot_vm)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm) ...@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
} }
} }
TEST_CASE(matmul_mv) TEST_CASE(dot_mv)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv) ...@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
} }
} }
TEST_CASE(matmul_mm1) TEST_CASE(dot_mm1)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1) ...@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
} }
} }
TEST_CASE(matmul_mm2) TEST_CASE(dot_mm2)
{ {
{ {
migraphx::program p; migraphx::program p;
...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2) ...@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
} }
} }
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment