#include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace cpu { template using matrix = blaze::CustomMatrix; // NOLINT template static auto make_mat(tensor_view x) { const auto& s = x.get_shape(); // assert(s.lens().size() == 2); std::size_t n_dims = s.lens().size(); std::size_t dim_0 = n_dims - 2; std::size_t dim_1 = n_dims - 1; if(s.transposed()) return matrix{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; return matrix{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; } template static void visit_mat(tensor_view x, F f) { auto mat = make_mat(x); if(x.get_shape().transposed()) f(blaze::trans(mat)); else f(mat); } template struct is_fast_gemm_type : std::false_type { }; template <> struct is_fast_gemm_type : std::true_type { }; template void migemm_impl(tensor_view cmat, tensor_view amat, tensor_view bmat, float alpha, float beta, std::true_type) { visit_mat(amat, [&](const auto& a) { visit_mat(bmat, [&](const auto& b) { auto c = make_mat(cmat); c = beta * c; if(alpha != 0.0) { c = c + alpha * a * b; } }); }); } template void migemm_impl(tensor_view cmat, tensor_view amat, tensor_view bmat, float alpha, float beta, std::false_type) { auto a_lens = amat.get_shape().lens(); auto b_lens = bmat.get_shape().lens(); auto c_lens = cmat.get_shape().lens(); std::size_t nc_dims = c_lens.size(); std::size_t na_dims = a_lens.size(); std::size_t nb_dims = b_lens.size(); auto k = a_lens[na_dims - 1]; assert(a_lens[na_dims - 1] == b_lens[nb_dims - 1]); assert(c_lens[nc_dims - 2] == a_lens[na_dims - 2]); assert(c_lens[nc_dims - 1] == b_lens[nb_dims - 1]); std::size_t a_len_diff = nc_dims - na_dims; std::size_t b_len_diff = nc_dims - nb_dims; std::vector a_idx(na_dims); std::vector b_idx(nb_dims); shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { std::transform(c_lens.begin() + a_len_diff, c_lens.end(), a_lens.begin(), a_idx.begin(), [&](auto i, auto j) { return (j == 1) ? 0 : i; }); std::transform(c_lens.begin() + b_len_diff, c_lens.end(), b_lens.begin(), b_idx.begin(), [&](auto i, auto j) { return (j == 1) ? 0 : i; }); double s = 0.0; dfor(k)([&](auto kk) { a_idx[na_dims - 1] = b_idx[nb_dims - 2] = kk; s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); }); cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; }); } template void migemm_impl( tensor_view cmat, tensor_view amat, tensor_view bmat, float alpha, float beta) { auto lens = cmat.get_shape().lens(); std::size_t num_matrices = std::accumulate( lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies()); if(num_matrices == 1) { migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type{}); } else { migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{}); } } void migemm( const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) { visit_all(c_arg, a_arg, b_arg)( [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); }); } } // namespace cpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx