Commit 13a9dcd9 authored by Shucai Xiao

extend the gemm operator to be more consistent with numpy.matmul

parent 235a463f
@@ -826,18 +826,12 @@ struct dot
         const shape& b = inputs.at(1);
         auto t         = a.type();
-        // change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6},
-        // which can be handled by numpy. as long as all previous
-        // dims are 1 except the last two dims, the two matrices
-        // are multipliable
-        if(std::any_of(a.lens().rbegin() + 2, a.lens().rend(), [](auto i) { return (i != 1); }))
-        {
-            MIGRAPHX_THROW("DOT: first matrix, dimensions before matrix dims must be 1");
-        }
-        if(std::any_of(b.lens().rbegin() + 2, b.lens().rend(), [](auto i) { return (i != 1); }))
-        {
-            MIGRAPHX_THROW("DOT: second matrix, dimensions before matrix dims must be 1");
-        }
+        // according to the specification of numpy.matmul(),
+        // inputs with more than two shape dims are acceptable
+        // as long as the leading dim values are the same in the two inputs
+        if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
+        {
+            MIGRAPHX_THROW("DOT: dim values mismatch");
+        }
         std::size_t dim_0 = a.lens().size() - 2;
...
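For context, the check introduced above can be tried in isolation. The snippet below is a minimal standalone sketch, not MIGraphX source: the batch_dims_match helper is a hypothetical name, but it applies the same std::equal comparison over all dims except the trailing two matrix dims, echoing the commit's rule that the leading (batch) dim values must agree between the two inputs.

// Minimal standalone sketch (not MIGraphX source). It applies the same
// std::equal comparison as the new dot shape check: every dim except the
// last two matrix dims must match between the two inputs.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// hypothetical helper; in the commit the comparison is written inline
bool batch_dims_match(const std::vector<std::size_t>& a, const std::vector<std::size_t>& b)
{
    // assumes both shapes have the same rank and at least two dims,
    // as the dot operator's inputs do
    if(a.size() != b.size() or a.size() < 2)
        return false;
    return std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2);
}

int main()
{
    // {2, 3, 4, 5} x {2, 3, 5, 7}: leading dims {2, 3} match -> accepted
    std::cout << batch_dims_match({2, 3, 4, 5}, {2, 3, 5, 7}) << '\n'; // prints 1
    // {2, 2, 4, 5} x {3, 2, 5, 7}: leading dims differ -> DOT would throw
    std::cout << batch_dims_match({2, 2, 4, 5}, {3, 2, 5, 7}) << '\n'; // prints 0
}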
@@ -93,7 +93,17 @@ template <class T>
 void migemm_impl(
     tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
 {
-    migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
+    auto lens      = amat.get_shape().lens();
+    bool batch_mul = std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) ==
+                     (*lens.rbegin()) * (*(lens.rbegin() + 1));
+    if(batch_mul)
+    {
+        migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
+    }
+    else
+    {
+        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
+    }
 }
 void migemm(
...
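A note on the dispatch above: the accumulated product of all dims equals the product of the trailing two dims only when every leading dim of amat is 1, so the fast gemm specialization is taken for a plain 2-D multiply while the std::false_type overload handles batched inputs. Despite its name, batch_mul is therefore true in the non-batched case. A standalone sketch of that test follows; the is_single_matrix name is hypothetical and not from the commit.

// Standalone sketch of the dispatch test: the total element count equals
// the product of the trailing two dims exactly when all leading dims are 1.
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// hypothetical helper; the commit writes this condition inline as batch_mul
bool is_single_matrix(const std::vector<std::size_t>& lens)
{
    // assumes at least two dims, as gemm operands always have
    std::size_t total =
        std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    return total == (*lens.rbegin()) * (*(lens.rbegin() + 1));
}

int main()
{
    std::cout << is_single_matrix({1, 1, 4, 5}) << '\n'; // prints 1: fast gemm path
    std::cout << is_single_matrix({2, 3, 4, 5}) << '\n'; // prints 0: generic batched path
}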
@@ -348,9 +348,9 @@ TEST_CASE(dot)
     }
     {
-        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
-        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
-        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 7}},
+        migraphx::shape s_m1{migraphx::shape::float_type, {2, 3, 4, 5}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 7}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 7}},
                      migraphx::op::dot{},
                      s_m1,
                      s_m2);
@@ -366,14 +366,14 @@ TEST_CASE(dot)
     }
     {
-        migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 6}};
-        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
+        migraphx::shape s_m1{migraphx::shape::float_type, {3, 1, 4, 6}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {3, 1, 5, 7}};
         throws_shape(migraphx::op::dot{}, s_m1, s_m2);
     }
     {
-        migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
-        migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
+        migraphx::shape s_m1{migraphx::shape::float_type, {2, 2, 4, 5}};
+        migraphx::shape s_m2{migraphx::shape::float_type, {3, 2, 5, 7}};
         throws_shape(migraphx::op::dot{}, s_m1, s_m2);
     }
...