/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace ref { template using matrix = blaze::CustomMatrix; // NOLINT template static auto make_mat(tensor_view x) { const auto& s = x.get_shape(); // assert(s.lens().size() == 2); std::size_t n_dims = s.lens().size(); std::size_t dim_0 = n_dims - 2; std::size_t dim_1 = n_dims - 1; if(s.transposed()) return matrix{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]}; return matrix{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]}; } template static void visit_mat(tensor_view x, F f) { auto mat = make_mat(x); if(x.get_shape().transposed()) f(blaze::trans(mat)); else f(mat); } template struct is_fast_gemm_type : std::false_type { }; template <> struct is_fast_gemm_type : std::true_type { }; template void migemm_impl( tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::true_type) { visit_mat(amat, [&](const auto& a) { visit_mat(bmat, [&](const auto& b) { auto c = make_mat(cmat); c = beta * c; // This is a simple optimization to avoid // compute A * B if alpha is 0.0 if(alpha != 0.0) { c = c + alpha * a * b; } }); }); } template void migemm_impl( tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta, std::false_type) { std::size_t n_dims = cmat.get_shape().lens().size(); std::size_t dim_0 = n_dims - 2; std::size_t dim_1 = n_dims - 1; auto k = amat.get_shape().lens()[dim_1]; assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); auto cs = cmat.get_shape(); par_for(cs.elements(), [&](auto i) { auto c_idx = cs.multi(i); auto a_idx = c_idx; auto b_idx = c_idx; double s = 0.0; dfor(k)([&](auto kk) { a_idx[dim_1] = b_idx[dim_0] = kk; s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); }); cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; }); } template void migemm_impl(tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta) { auto lens = amat.get_shape().lens(); bool batch_mul = std::accumulate( lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies()) == 1; if(batch_mul) { migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type{}); } else { migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{}); } } template void migemm_tpl( const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta) { visit_all(c_arg, a_arg, b_arg)( [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); }); } void migemm( const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta) { migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); } void migemm(const argument& c_arg, const argument& a_arg, const argument& b_arg, int32_t alpha, int32_t beta) { migemm_tpl(c_arg, a_arg, b_arg, alpha, beta); } } // namespace ref } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx