"...composable_kernel.git" did not exist on "66edb2590d47a4dd4208e10998b19d0318b1cd71"
Commit 4a96c2e4 authored by Chao Liu

refactor

parent 70e8cc76
@@ -11,9 +11,8 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "device_tensor.hpp"
-//#include "device_gemm_xdl.hpp"
-//#include "device_gemm_xdl_c_shuffle.hpp"
-//#include "device_gemm_xdl_cshuffle.hpp"
+#include "device_gemm_xdl.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "device_gemm_xdl_cshuffle_v2.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
@@ -53,12 +52,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-//< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // // 1-stage prefetch
-< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // // 2-stage prefetch
 // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-#elif 1
+#elif 0
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
 //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -67,7 +66,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 // all thread
 < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 0, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-#elif 1
+#elif 0
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
 //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
@@ -82,7 +81,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
 // < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
 #endif
...
@@ -317,13 +317,14 @@ struct DeviceGemmXdl
         const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-        const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+        const auto K =
+            arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

-        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

         float ave_time = 0;

-        if(has_main_k0_block_loop)
+        if(has_main_k_block_loop)
         {
             const auto kernel = kernel_gemm_xdlops_v2r3<
                 GridwiseGemm,
...
@@ -462,13 +462,14 @@ struct DeviceGemm_Xdl_CShuffle
         const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-        const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
+        const auto K =
+            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

-        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

         float ave_time = 0;

-        if(has_main_k0_block_loop)
+        if(has_main_k_block_loop)
         {
             const auto kernel = kernel_gemm_xdl_cshuffle_v1<
                 GridwiseGemm,
...
@@ -464,13 +464,14 @@ struct DeviceGemm_Xdl_CShuffle_v2
         const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

-        const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
+        const auto K =
+            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

-        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

         float ave_time = 0;

-        if(has_main_k0_block_loop)
+        if(has_main_k_block_loop)
         {
             const auto kernel = kernel_gemm_xdl_cshuffle_v2<
                 GridwiseGemm,
...
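Reviewer note: all three device ops above now derive the full K from the (K0, M, K1) descriptor, K = GetLength(I0) * GetLength(I2), and delegate the main-loop decision to the pipeline. A standalone sketch of that arithmetic, with illustrative sizes (KPerBlock, K0, K1 here are made up, not taken from this commit):

#include <cassert>

constexpr int KPerBlock = 32; // illustrative tile size

// Mirrors GridwiseGemmPipeline_v1<1>::CalculateHasMainLoop: main loop iff num_loop > 1.
constexpr bool CalculateHasMainKBlockLoop(int K)
{
    const int num_loop = K / KPerBlock;
    return num_loop > 1;
}

int main()
{
    const int K0 = 64, K1 = 8;             // descriptor lengths at I0 and I2
    const int K  = K0 * K1;                // host-side K = 512
    assert(CalculateHasMainKBlockLoop(K)); // 512 / 32 = 16 loops > 1
    return 0;
}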
@@ -17,6 +17,88 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename GemmDesc,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          bool HasMainKBlockLoop,
+          index_t MaxGroupCount>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_grouped_gemm_xdlops_v2r3(
+            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
+            const index_t group_count,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    const index_t block_id = get_block_1d_id();
+
+#if 1
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
+           i < group_count)
+        {
+            auto group_id = i;
+
+            GridwiseGemm::template Run<HasMainKBlockLoop>(
+                gemm_desc_[group_id].a_ptr,
+                gemm_desc_[group_id].b_ptr,
+                gemm_desc_[group_id].c_ptr,
+                p_shared,
+                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
+                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
+                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                a_element_op,
+                b_element_op,
+                c_element_op,
+                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
+        }
+    });
+#else
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
+
+    index_t group_id = 0;
+
+    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
+        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+                    i < group_count)
+                       ? i
+                       : group_id;
+    });
+
+    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(
+        gemm_desc_ptr[group_id].a_ptr,
+        gemm_desc_ptr[group_id].b_ptr,
+        gemm_desc_ptr[group_id].c_ptr,
+        p_shared,
+        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
+        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
+        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+        a_element_op,
+        b_element_op,
+        c_element_op,
+        gemm_desc_ptr[group_id].block_2_ctile_map_,
+        block_id_grp);
+#endif
+#else
+    ignore = gemm_desc_;
+    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
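Reviewer note: the kernel above resolves which GEMM a block works on by scanning a compile-time-sized descriptor array for the [BlockStart_, BlockEnd_) range containing the block id. A plain-C++ model of that lookup (simplified types; the real kernel uses static_for over a StaticallyIndexedArray):

#include <cstdio>

struct GemmDesc { int BlockStart_; int BlockEnd_; }; // trimmed to the two fields used here

constexpr int MaxGroupCount = 4;

int FindGroup(const GemmDesc (&desc)[MaxGroupCount], int group_count, int block_id)
{
    for(int i = 0; i < MaxGroupCount; ++i) // plays the role of static_for<0, MaxGroupCount, 1>
        if(i < group_count && block_id >= desc[i].BlockStart_ && block_id < desc[i].BlockEnd_)
            return i;
    return -1; // block id not owned by any group
}

int main()
{
    const GemmDesc desc[MaxGroupCount] = {{0, 8}, {8, 20}, {20, 24}, {0, 0}};
    std::printf("group of block 10: %d\n", FindGroup(desc, 3, 10)); // prints 1
    return 0;
}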
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
...
-#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
-#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
+#pragma once

 #include "common_header.hpp"

 namespace ck {
-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          index_t NumPrefetch,
-          bool HasMainLoop>
+template <index_t NumPrefetch>
 struct GridwiseGemmPipeline_v1;
 // 1-stage prefetch
-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               1,
-                               HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<1>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
-    static __device__ void Run(const AGridDesc& a_grid_desc,
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
                                const AGridBuffer& a_grid_buf,
@@ -75,51 +51,7 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
                                CThreadBuffer& c_thread_buf,
                                index_t num_loop)
     {
-#if 0
+#if 1
-        // preload data into LDS
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-        // main body
-        if constexpr(HasMainLoop)
-        {
-            index_t i = 0;
-
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-                ++i;
-            } while(i < (num_loop - 1));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
-#elif 0
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -244,41 +176,38 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
 };
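Reviewer note: the shape of this refactor is that every buffer/descriptor type moves off the class template onto a templated static Run, so the pipeline is selected by prefetch depth alone and is a stateless policy. A minimal plain-C++ model of the pattern (types and loop body deliberately simplified):

#include <cstdio>

template <int NumPrefetch>
struct Pipeline; // primary template, specialized per prefetch depth

template <>
struct Pipeline<1>
{
    static constexpr bool CalculateHasMainLoop(int num_loop) { return num_loop > 1; }

    template <bool HasMainLoop, typename ABuf, typename BBuf>
    static void Run(ABuf&, BBuf&, int num_loop)
    {
        if constexpr(HasMainLoop)
            for(int i = 0; i + 1 < num_loop; ++i)
            { /* overlap copy of tile i+1 with gemm on tile i */ }
        /* tail iteration */
    }
};

int main()
{
    int a = 0, b = 0;
    Pipeline<1>::Run<true>(a, b, 8); // buffer types deduced per call, not per class
    std::printf("has main loop: %d\n", Pipeline<1>::CalculateHasMainLoop(8));
    return 0;
}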
 // 2-stage prefetch
-template <typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               2,
-                               HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<2>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return (num_loop / 2) > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
@@ -395,4 +324,3 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
 };

 } // namespace ck
-#endif
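Reviewer note: with the predicates on the specializations, callers no longer hard-code per-stage rules. A quick standalone check of the two 2-stage predicates above, with illustrative loop counts:

#include <cassert>

constexpr bool IsSupported2Stage(int num_loop) { return num_loop % 2 == 0; }
constexpr bool HasMainLoop2Stage(int num_loop) { return (num_loop / 2) > 1; }

int main()
{
    assert(IsSupported2Stage(8));  // even K-loop count: supported
    assert(!IsSupported2Stage(7)); // odd count rejected, as the TODO notes
    assert(HasMainLoop2Stage(8));  // 4 double-iterations, main loop present
    assert(!HasMainLoop2Stage(2)); // only the tail pair remains
    return 0;
}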
@@ -4,62 +4,8 @@
 namespace ck {
-template <typename ABBlockTransferThreadGroup,
-          typename BlockGemmThreadGroup,
-          typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          index_t NumPrefetch,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v2;
-
-// 1-stage prefetch
-template <typename ABBlockTransferThreadGroup,
-          typename BlockGemmThreadGroup,
-          typename AGridDesc,
-          typename ABlockDesc,
-          typename ABlockTransfer,
-          typename AGridBuffer,
-          typename ABlockBuffer,
-          typename ABlockTransferStep,
-          typename BGridDesc,
-          typename BBlockDesc,
-          typename BBlockTransfer,
-          typename BGridBuffer,
-          typename BBlockBuffer,
-          typename BBlockTransferStep,
-          typename BlockwiseGemm,
-          typename CThreadBuffer,
-          bool HasMainLoop>
-struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
-                               BlockGemmThreadGroup,
-                               AGridDesc,
-                               ABlockDesc,
-                               ABlockTransfer,
-                               AGridBuffer,
-                               ABlockBuffer,
-                               ABlockTransferStep,
-                               BGridDesc,
-                               BBlockDesc,
-                               BBlockTransfer,
-                               BGridBuffer,
-                               BBlockBuffer,
-                               BBlockTransferStep,
-                               BlockwiseGemm,
-                               CThreadBuffer,
-                               1,
-                               HasMainLoop>
+template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
+struct GridwiseGemmPipeline_v2
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -69,6 +15,31 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         // TODO static assert
     }
+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop / 2 > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename CThreadBuffer>
     static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                       const ABlockDesc& a_block_desc,
                                                       ABlockTransfer& a_block_copy,
@@ -151,6 +122,11 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         }
     }
+    template <bool HasMainLoop,
+              typename ABlockBuffer,
+              typename BBlockBuffer,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                 BBlockBuffer& b_block_buf,
                                                 const BlockwiseGemm& block_gemm,
@@ -161,7 +137,6 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         c_thread_buf.Clear();

         // main body
-        // FIXME: HasMainLoop = (num_loop) > 2
         if constexpr(HasMainLoop)
         {
             index_t i = 0;
@@ -205,6 +180,21 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
         }
     }
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_block_copy,
@@ -223,23 +213,24 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
     {
         if(ABBlockTransferThreadGroup::IsBelong())
         {
-            RunABBlockTransferPipeline(a_grid_desc,
+            RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
                                                    a_block_desc,
                                                    a_block_copy,
                                                    a_grid_buf,
                                                    a_block_buf,
                                                    a_block_copy_step,
                                                    b_grid_desc,
                                                    b_block_desc,
                                                    b_block_copy,
                                                    b_grid_buf,
                                                    b_block_buf,
                                                    b_block_copy_step,
                                                    num_loop);
         }
         else if(BlockGemmThreadGroup::IsBelong())
         {
-            RunBlockGemmPipeline(a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
+            RunBlockGemmPipeline<HasMainLoop>(
+                a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
         }
     }
 };
...
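Reviewer note: unlike v1, this pipeline splits the block into a copy thread group and a GEMM thread group and routes each thread inside Run by membership. A simplified model of the dispatch (assumption: the real IsBelong() takes no argument and reads the hardware thread id; here it is passed explicitly):

#include <cstdio>

struct CopyGroup { static bool IsBelong(int tid) { return tid < 128; } };
struct GemmGroup { static bool IsBelong(int tid) { return tid >= 128 && tid < 384; } };

template <typename CopyG, typename GemmG>
void Run(int tid)
{
    if(CopyG::IsBelong(tid))
        std::printf("thread %d: A/B block-transfer pipeline\n", tid);
    else if(GemmG::IsBelong(tid))
        std::printf("thread %d: block GEMM pipeline\n", tid);
}

int main()
{
    Run<CopyGroup, GemmGroup>(0);   // copy side
    Run<CopyGroup, GemmGroup>(200); // gemm side
    return 0;
}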
@@ -4,8 +4,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -41,17 +41,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -125,6 +125,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -190,10 +194,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
         const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
-        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
-        //               "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
@@ -208,21 +208,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
@@ -242,12 +231,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return grid_size;
     }

-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
-    {
-        const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
-
-        return has_main_k0_block_loop;
-    }
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / KPerBlock;
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
+    }

     __host__ __device__ static constexpr auto
@@ -315,7 +303,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -358,28 +346,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               AElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
                                               Sequence<AK0, MPerBlock, AK1>,
                                               ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
                                               decltype(a_grid_desc_ak0_m_ak1),
                                               decltype(a_block_desc_ak0_m_ak1),
                                               ABlockTransferSrcAccessOrder,
                                               Sequence<1, 0, 2>,
                                               ABlockTransferSrcVectorDim,
                                               2,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_AK1,
                                               1,
                                               1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
                                               true,
                                               NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
@@ -389,28 +377,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                               BElementwiseOperation,
                                               ck::tensor_operation::element_wise::PassThrough,
                                               InMemoryDataOperationEnum::Set,
                                               Sequence<BK0, NPerBlock, BK1>,
                                               BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                               BBlockTransferThreadClusterArrangeOrder,
                                               FloatAB,
                                               FloatAB,
                                               decltype(b_grid_desc_bk0_n_bk1),
                                               decltype(b_block_desc_bk0_n_bk1),
                                               BBlockTransferSrcAccessOrder,
                                               Sequence<1, 0, 2>,
                                               BBlockTransferSrcVectorDim,
                                               2,
                                               BBlockTransferSrcScalarPerVector,
                                               BBlockTransferDstScalarPerVector_BK1,
                                               1,
                                               1,
                                               BThreadTransferSrcResetCoordinateAfterRun,
                                               true,
                                               NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
@@ -429,7 +417,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                                 FloatAB,
                                                                 FloatGemmAcc,
                                                                 decltype(a_block_desc_ak0_m_ak1),
@@ -457,43 +445,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                   a_block_desc_ak0_m_ak1,
                                   a_blockwise_copy,
                                   a_grid_buf,
                                   a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_bk0_n_bk1,
                                   b_block_desc_bk0_n_bk1,
                                   b_blockwise_copy,
                                   b_grid_buf,
                                   b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm,
                                   c_thread_buf,
                                   num_k_block_main_loop);

         // shuffle C and write out
         {
@@ -609,8 +579,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             ck::tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
-        auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
-            BlockSize,                  // index_t BlockSize,
+        auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
+            ThisThreadBlock,            // ThreadGroup
             CElementwiseOperation,      // ElementwiseOperation,
             CGlobalMemoryDataOperation, // DstInMemOp,
             Sequence<1,
...
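Reviewer note: the blockwise -> thread-group rename of the v4r1/v6r1 transfers tracks an interface change: the first template parameter is now a thread-group type rather than an integer BlockSize, so a sub-group (as in the v2 pipeline) can drive a copy. A toy model of that parameterization (simplified; the real IsBelong() reads the hardware thread id):

#include <cstdio>

template <int NThread>
struct ThisThreadBlockModel
{
    static constexpr int GetNumOfThread() { return NThread; }
    static bool IsBelong(int tid) { return tid < NThread; }
};

template <typename ThreadGroup>
struct ThreadGroupTransferModel
{
    static void RunRead(int tid)
    {
        if(ThreadGroup::IsBelong(tid))
            std::printf("thread %d of %d copies its slice\n",
                        tid, ThreadGroup::GetNumOfThread());
    }
};

int main()
{
    ThreadGroupTransferModel<ThisThreadBlockModel<256>>::RunRead(3);
    return 0;
}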
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -41,17 +41,17 @@ __global__ void
 {
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   block_2_ctile_map);
 }

 template <typename FloatAB,
@@ -115,7 +115,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     static constexpr auto BK1 = Number<BK1Value>{};

     using ThisThreadBlock =
-        AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
+        ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;

 #if 0
     struct ABBlockTransferThreadGroup
@@ -158,6 +158,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
 #endif

+#if 1
+    // gridwise GEMM pipeline
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+#else
+    // gridwise GEMM pipeline
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
+                                                     BlockGemmThreadGroup,
+                                                     NumGemmKPrefetchStage>;
+#endif
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -223,10 +233,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
         const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
-        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
-        //               "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
@@ -241,21 +247,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
@@ -275,12 +270,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
         return grid_size;
     }

-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
-    {
-        const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
-
-        return has_main_k0_block_loop;
-    }
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / KPerBlock;
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
+    }

     __host__ __device__ static constexpr auto
@@ -348,7 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -493,63 +487,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-#if 1
-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-#else
-        // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
-                                    BlockGemmThreadGroup,
-                                    remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
-#endif
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
-                                   a_block_desc_ak0_m_ak1,
-                                   a_blockwise_copy,
-                                   a_grid_buf,
-                                   a_block_buf,
-                                   a_block_slice_copy_step,
-                                   b_grid_desc_bk0_n_bk1,
-                                   b_block_desc_bk0_n_bk1,
-                                   b_blockwise_copy,
-                                   b_grid_buf,
-                                   b_block_buf,
-                                   b_block_slice_copy_step,
-                                   blockwise_gemm,
-                                   c_thread_buf,
-                                   num_k_block_main_loop);
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+                                                          a_block_desc_ak0_m_ak1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_bk0_n_bk1,
+                                                          b_block_desc_bk0_n_bk1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          num_k_block_main_loop);

         // shuffle C and write out
         {
...
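Reviewer note: the pipeline is now chosen once, as the GridwiseGemmPipe alias, instead of being constructed inline inside Run(); the #if 1 above hard-codes v1 for now. The same compile-time selection could be expressed with std::conditional_t (a sketch, not code from this commit; PipeV1, PipeV2, CopyG, GemmG are placeholders):

#include <type_traits>

template <int NumPrefetch> struct PipeV1 {};
template <typename CopyG, typename GemmG> struct PipeV2 {};

struct CopyG {};
struct GemmG {};

constexpr bool UseV1 = true; // what "#if 1" pins down in the diff

using GridwiseGemmPipe = std::conditional_t<UseV1, PipeV1<1>, PipeV2<CopyG, GemmG>>;

static_assert(std::is_same_v<GridwiseGemmPipe, PipeV1<1>>);

int main() { return 0; }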
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
+#pragma once

 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
@@ -22,7 +20,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -42,17 +40,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
                                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -67,88 +65,6 @@ __global__ void
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
-template <typename GridwiseGemm,
-          typename FloatAB,
-          typename FloatC,
-          typename GemmDesc,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          bool HasMainK0BlockLoop,
-          index_t MaxGroupCount>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-#endif
-        kernel_grouped_gemm_xdlops_v2r3(
-            const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
-            const index_t group_count,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CElementwiseOperation c_element_op)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    const index_t block_id = get_block_1d_id();
-
-#if 1
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
-           i < group_count)
-        {
-            auto group_id = i;
-
-            GridwiseGemm::template Run<HasMainK0BlockLoop>(
-                gemm_desc_[group_id].a_ptr,
-                gemm_desc_[group_id].b_ptr,
-                gemm_desc_[group_id].c_ptr,
-                p_shared,
-                gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
-                gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
-                gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                a_element_op,
-                b_element_op,
-                c_element_op,
-                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
-        }
-    });
-#else
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
-
-    index_t group_id = 0;
-
-    static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
-                    i < group_count)
-                       ? i
-                       : group_id;
-    });
-
-    const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
-
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(
-        gemm_desc_ptr[group_id].a_ptr,
-        gemm_desc_ptr[group_id].b_ptr,
-        gemm_desc_ptr[group_id].c_ptr,
-        p_shared,
-        gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
-        gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
-        gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-        a_element_op,
-        b_element_op,
-        c_element_op,
-        gemm_desc_ptr[group_id].block_2_ctile_map_,
-        block_id_grp);
-#endif
-#else
-    ignore = gemm_desc_;
-    ignore = group_count;
-    ignore = a_element_op;
-    ignore = b_element_op;
-    ignore = c_element_op;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
-}
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
@@ -187,7 +103,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
@@ -202,6 +118,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
@@ -291,21 +211,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
@@ -335,12 +244,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return grid_size;
     }

-    // TODO move this function into GEMM-pipeline class
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
-    {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
-
-        return has_main_k0_block_loop;
-    }
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / (K0PerBlock * K1);
+
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
+    }
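Reviewer note: since K = K0 * K1 for this descriptor, num_loop = K / (K0PerBlock * K1) is the same count the old K0-based check produced; e.g. with illustrative K0 = 64, K1 = 8, K0PerBlock = 4:

static_assert(64 * 8 / (4 * 8) == 64 / 4, "K/(K0PerBlock*K1) == K0/K0PerBlock");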
     __host__ __device__ static constexpr auto
@@ -379,7 +287,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     }();

     using BlockwiseGemm =
-        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                             FloatAB,
                                                             FloatAcc,
                                                             decltype(a_block_desc_k0_m_k1),
@@ -433,7 +341,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));

     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
@@ -499,7 +407,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 1,
                 AThreadTransferSrcResetCoordinateAfterRun,
                 true,
-                NumPrefetch>(
+                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
@@ -530,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                 1,
                 BThreadTransferSrcResetCoordinateAfterRun,
                 true,
-                NumPrefetch>(
+                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
@@ -547,7 +455,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // sanity check

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<ThisThreadBlock,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
@@ -575,41 +483,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                                   a_block_desc_k0_m_k1,
-                                   a_blockwise_copy,
-                                   a_grid_buf,
-                                   a_block_buf,
-                                   a_block_slice_copy_step,
-                                   b_grid_desc_k0_n_k1,
-                                   b_block_desc_k0_n_k1,
-                                   b_blockwise_copy,
-                                   b_grid_buf,
-                                   b_block_buf,
-                                   b_block_slice_copy_step,
-                                   blockwise_gemm,
-                                   c_thread_buf,
-                                   K0BlockMainLoop);
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          num_k_block_main_loop);

         // output: register to global memory
         {
@@ -692,4 +582,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 };

 } // namespace ck
-#endif
@@ -4,7 +4,7 @@
 namespace ck {

 template <index_t ThreadPerBlock>
-struct AnyThreadBlock
+struct ThisThreadBlock
 {
     static constexpr index_t kNumThread_ = ThreadPerBlock;
...