/* * The MIT License (MIT) * * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/
// NOTE(review): the header names of the following #include directives are
// missing in this copy of the file (only the bare directives survive);
// restore them from the repository before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// HIP / Composable Kernel (CK) source for the "ck_gemm" operator. It is kept as
// a raw string and compiled at runtime by ck_gemm_compiler::compile_op through
// compile_hip_code_object. compile_op selects the entry point by the name
// "ck_gemm_kernel", which matches the extern "C" __global__ function declared
// inside this string.
//
// The device instance is configured for F16 A/B/C data with F32 accumulation
// (and an F32 C-shuffle type), row-major A and C, and column-major B.
//
// As written, the source is a work-in-progress prototype:
//   - The GEMM problem size is hard-coded (M = 3840, N = 4096, K = 4096, all
//     strides 4096) rather than derived from the argument shapes. These locals
//     and the a/b/c element-op objects are never used afterwards.
//   - The kernel parameters are named a_p/b_p/c_p, but the final launch passes
//     p_a/p_b/p_c, which are not declared anywhere in this string.
//   - MakeAGridDescriptor_AK0_M_AK1, GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1,
//     NumGemmKPrefetchStage, BlockSize, LoopSched and similar names are used
//     unqualified and are not defined in this string. Presumably they are meant
//     to come from a CK device-op class (e.g. via DeviceGemmInstance::); confirm.
//   - It launches `kernel<<<...>>>` from inside a __global__ function. HIP
//     generally does not support device-side kernel launches; confirm.
// NOTE(review): in this copy the string is collapsed onto a few long lines, so
// every `//` comment inside it (e.g. `// clang-format off`) would comment out
// the remainder of its line when the source is compiled. The #include targets
// and several template argument lists (e.g. `template using S = ck::Sequence;`,
// `remove_reference_t`) also appear to be missing. The literal is left
// untouched here; restore it from the repository.
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" template using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F16; using BDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; using CDataType = F16; using ALayout = Row; using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; namespace migraphx { extern "C" { static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl // clang-format off //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; // clang-format on __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) { // GEMM shape ck::index_t M = 3840; ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t StrideA = 4096; ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, CDataType, AElementOp, BElementOp, CElementOp, ck::InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, 
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t, AElementOp, BElementOp, CElementOp, remove_reference_t, true>; kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c); } } } // namespace migraphx )__migraphx__";

// Compiler for the "ck_gemm" operator. It turns the ck_gemm_kernel source
// above into a HIP code object sized for the operator's argument shapes.
struct ck_gemm_compiler : compiler
{
    // Operator names handled by this compiler.
    // NOTE(review): the template argument of std::vector is missing in this
    // copy; presumably std::vector<std::string>. Confirm.
    std::vector names() const { return {"ck_gemm"}; }

    // Compiles ck_gemm_kernel for the given argument shapes.
    //   ctx    - context passed to compute_global_for to size the launch.
    //   inputs - argument shapes; the last one is used as the output shape.
    //   v      - operator value forwarded to set_launch_params.
    // Returns the operation produced by compile_hip_code_object.
    // NOTE(review): the template argument of std::vector is missing in this
    // copy; presumably std::vector<shape>. Confirm.
    // NOTE(review): inputs.back() assumes inputs is non-empty.
    operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
    {
        hip_compile_options options;
        auto out_s = inputs.back();
        // The global launch size is derived from the output element count,
        // while the kernel source hard-codes its own GEMM size and launch.
        // NOTE(review): confirm this launch configuration fits the kernel.
        options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
        options.inputs = inputs;
        options.output = out_s;
        // Must match the extern "C" __global__ function name inside
        // ck_gemm_kernel.
        options.kernel_name    = "ck_gemm_kernel";
        options.virtual_inputs = inputs;
        return compile_hip_code_object(ck_gemm_kernel, options);
    }

    // Compiles the operator for instruction `ins`: collects the shapes of its
    // inputs, compiles with op's serialized value, and wraps the resulting
    // operation with replace(...).
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx