#include #include #include #include #include namespace migraph { namespace gpu { template void generic_rocblas_gemm(shape::as, Ts&&... xs) { rocblas_sgemm(std::forward(xs)...); } template void generic_rocblas_gemm(shape::as, Ts&&... xs) { rocblas_dgemm(std::forward(xs)...); } template void generic_rocblas_gemm(shape::as, Ts&&... xs) { rocblas_hgemm(std::forward(xs)...); } template void generic_rocblas_gemm(shape::as, Ts&&...) { MIGRAPH_THROW("Type unsupported by rocblas"); } template T to_rocblas_type(T x) { return x; } rocblas_half to_rocblas_type(half x) { return reinterpret_cast(x); } shape miopen_gemm::compute_shape(const std::vector& inputs) const { check_shapes{inputs, *this}.has(3); return op.compute_shape({inputs.at(0), inputs.at(1)}); } argument miopen_gemm::compute(context& ctx, const shape& output_shape, const std::vector& args) const { float alpha = 1.0f; float beta = 0.0f; bool transa = args[0].get_shape().transposed(); bool transb = args[1].get_shape().transposed(); rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0]; rocblas_int ldc = args[2].get_shape().strides()[0]; rocblas_int m = output_shape.lens()[0]; rocblas_int n = output_shape.lens()[1]; rocblas_int k = args[0].get_shape().lens()[1]; output_shape.visit_type([&](auto as) { auto alpha_r = to_rocblas_type(as(alpha)); auto beta_r = to_rocblas_type(as(beta)); generic_rocblas_gemm(as, ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none, n, m, k, &alpha_r, args[1].implicit(), ldb, args[0].implicit(), lda, &beta_r, args[2].implicit(), ldc); }); return args[2]; } } // namespace gpu } // namespace migraph