Commit d656321b authored by wangshaojie6's avatar wangshaojie6
Browse files

add 4 stage one

parent 575a50dd
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "device_gemm_xdl_producer_consumer_cshuffle.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -43,12 +44,23 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -43,12 +44,23 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
#if 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>; < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>;
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#endif
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#define CK_USE_LAUNCH_BOUNDS 1 #define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS #ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256 #define CK_MAX_THREAD_PER_BLOCK 512
#define CK_MIN_BLOCK_PER_CU 1 #define CK_MIN_BLOCK_PER_CU 1
#endif #endif
......
#pragma once
#include "common_header.hpp"
namespace ck {
// Primary template for the producer/consumer gridwise GEMM pipeline.
// Only explicit specializations (selected by NumGemmKPrefetchStage) are
// defined; instantiating an unsupported stage count is a compile error.
template <typename ABBlockTransferThreadGroup,
          typename BlockGemmThreadGroup,
          index_t NumGemmKPrefetchStage>
struct GridwiseGemmPipelineProducerConsumer;
// 1-stage prefetch
// 1-stage-prefetch specialization.
//
// The thread block is partitioned into two cooperating thread groups:
//   - ABBlockTransferThreadGroup: the "producer" group, which streams A/B
//     tiles from global memory into LDS (RunABBlockTransferPipeline);
//   - BlockGemmThreadGroup: the "consumer" group, which runs the blockwise
//     GEMM on tiles already resident in LDS (RunBlockGemmPipeline).
// The two groups rendezvous only through block_sync_lds(); the two pipeline
// bodies are written so that their per-iteration (and tail) sync counts match
// exactly. Do not reorder statements in either pipeline without re-checking
// that pairing.
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
{
    // The tail below always drains exactly two iterations (num_loop - 2 and
    // num_loop - 1), so only even loop counts are handled correctly.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop % 2 == 0;
    }

    // The main loop is needed only when more than one iteration pair remains
    // after the two-iteration tail.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop / 2 > 1;
    }

    // Producer-side pipeline: executed only by ABBlockTransferThreadGroup.
    // Each iteration i overlaps the global read of tile i+2 with the LDS write
    // of tile i+1, while the consumer group computes on tile i. The GEMM
    // comments mark where the consumer runs between the paired syncs.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
                                                      const ABlockDesc& a_block_desc,
                                                      ABlockTransfer& a_block_copy,
                                                      const AGridBuffer& a_grid_buf,
                                                      ABlockBuffer& a_block_buf,
                                                      const ABlockTransferStep& a_block_copy_step,
                                                      const BGridDesc& b_grid_desc,
                                                      const BBlockDesc& b_block_desc,
                                                      BBlockTransfer& b_block_copy,
                                                      const BGridBuffer& b_grid_buf,
                                                      BBlockBuffer& b_block_buf,
                                                      const BBlockTransferStep& b_block_copy_step,
                                                      index_t num_loop)
    {
        // Prologue: fetch tiles 0 and 1 so the main loop always has one tile
        // in registers and one being written to LDS.
        // global read 0
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0 (A)
        a_block_copy.RunWrite(a_block_desc, a_block_buf);
        // global read 1 (A)
        a_block_copy.RunRead(a_grid_desc, a_grid_buf);

        // LDS write 0 (B)
        b_block_copy.RunWrite(b_block_desc, b_block_buf);
        // global read 1 (B)
        b_block_copy.RunRead(b_grid_desc, b_grid_buf);

        // main body: two syncs per iteration, pairing with the consumer's two
        // syncs around its GEMM call.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i (runs in the consumer thread group)

                block_sync_lds();

                // move to i + 2
                a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // LDS write i + 1 (A)
                a_block_copy.RunWrite(a_block_desc, a_block_buf);
                // global read i + 2 (A)
                a_block_copy.RunRead(a_grid_desc, a_grid_buf);

                // LDS write i + 1 (B)
                b_block_copy.RunWrite(b_block_desc, b_block_buf);
                // global read i + 2 (B)
                b_block_copy.RunRead(b_grid_desc, b_grid_buf);

                ++i;
            } while(i < (num_loop - 2));
        }

        // tail: write the final prefetched tile and let the consumer finish
        // the last two GEMMs (three syncs, matching the consumer tail).
        {
            block_sync_lds();

            // GEMM num_loop - 2 (runs in the consumer thread group)

            block_sync_lds();

            // LDS write num_loop - 1
            a_block_copy.RunWrite(a_block_desc, a_block_buf);
            b_block_copy.RunWrite(b_block_desc, b_block_buf);

            block_sync_lds();

            // GEMM num_loop - 1 (runs in the consumer thread group)
        }
    }

    // Consumer-side pipeline: executed only by BlockGemmThreadGroup.
    // Mirrors RunABBlockTransferPipeline sync-for-sync; the commented-out
    // copy steps mark where the producer group works between syncs.
    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
                                                BBlockBuffer& b_block_buf,
                                                const BlockwiseGemm& block_gemm,
                                                CThreadBuffer& c_thread_buf,
                                                index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();

        // main body: two syncs per iteration, pairing with the producer.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i
                block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                // (producer group meanwhile does:)
                // move to i + 2
                // LDS write i + 1
                // global read i + 2
                // LDS write i + 1
                // global read i + 2

                ++i;
            } while(i < (num_loop - 2));
        }

        // tail: two remaining GEMMs with a sync in between while the producer
        // writes the final tile (three syncs, matching the producer tail).
        {
            block_sync_lds();

            // GEMM num_loop - 2
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            // (producer group: LDS write num_loop - 1)

            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }

    // Entry point: dispatches the calling thread to exactly one of the two
    // pipelines based on its thread-group membership. Threads belonging to
    // neither group do nothing (and must not be expected at the syncs --
    // NOTE(review): assumes the two groups cover all threads that reach the
    // block_sync_lds() barriers; confirm against the caller).
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_block_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_block_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& block_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        if(ABBlockTransferThreadGroup::IsBelong())
        {
            RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
                                                    a_block_desc,
                                                    a_block_copy,
                                                    a_grid_buf,
                                                    a_block_buf,
                                                    a_block_copy_step,
                                                    b_grid_desc,
                                                    b_block_desc,
                                                    b_block_copy,
                                                    b_grid_buf,
                                                    b_block_buf,
                                                    b_block_copy_step,
                                                    num_loop);
        }
        else if(BlockGemmThreadGroup::IsBelong())
        {
            RunBlockGemmPipeline<HasMainLoop>(
                a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
        }
    }
};
} // namespace ck
...@@ -459,12 +459,12 @@ struct GridwiseGemmPipeline_v2<4> ...@@ -459,12 +459,12 @@ struct GridwiseGemmPipeline_v2<4>
__host__ __device__ static constexpr bool IsSupported(index_t num_loop) __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{ {
// TODO: improve applicability // TODO: improve applicability
return num_loop > 4; return num_loop % 4 == 0;
} }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{ {
return num_loop > 4; return num_loop / 4 > 1;
} }
template <bool HasMainLoop, template <bool HasMainLoop,
...@@ -498,34 +498,15 @@ struct GridwiseGemmPipeline_v2<4> ...@@ -498,34 +498,15 @@ struct GridwiseGemmPipeline_v2<4>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
// global read 0 static_for<0, 4, 1>{}([&](auto i_pre){
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); // global read i_pre
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_pre>{});
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_pre>{});
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// move to 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// global read 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I2);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I2);
// move to 3 // move to i_pre + 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
});
// global read 3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I3);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I3);
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
...@@ -536,99 +517,36 @@ struct GridwiseGemmPipeline_v2<4> ...@@ -536,99 +517,36 @@ struct GridwiseGemmPipeline_v2<4>
{ {
do do
{ {
// move to i + 4 static_for<0, 4, 1>{}([&](auto i_main){
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// global Read i + 4
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
// LDS write i
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// global Read i + 4
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
block_sync_lds();
// GEMM i // LDS write i_main
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_main>{});
// global Read i_main + 3
block_sync_lds(); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, Number<i_main>{});
// move to i + 5
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
// global read i + 5
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// global read i + 5
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
block_sync_lds();
// GEMM i + 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write i_main
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_main>{});
// global Read i_main + 3
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<i_main>{});
// move to i + 6 // move to i_main + 3
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I2);
// global read i + 6
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I2);
// LDS write i + 2
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I2);
// global read i + 6
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I2);
block_sync_lds(); block_sync_lds();
// GEMM i + 2 // GEMM i_main
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds(); block_sync_lds();
});
// move to i + 7
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I3);
// global read i + 7
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I3);
// LDS write i + 3
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I3);
// global read i + 7
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I3);
block_sync_lds();
// GEMM i + 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
i += 4; i += 4;
} while(i < (num_loop - 4)); } while(i < (num_loop - 4));
} }
// tail // tail
if (i == num_loop - 4)
{
static_for<0, I4, 1>{}([&](auto i_res){ static_for<0, I4, 1>{}([&](auto i_res){
// Write num_loop - 3 // Write num_loop - 3
...@@ -642,61 +560,6 @@ struct GridwiseGemmPipeline_v2<4> ...@@ -642,61 +560,6 @@ struct GridwiseGemmPipeline_v2<4>
block_sync_lds(); block_sync_lds();
}); });
}
// tail
if (i == num_loop - 3)
{
static_for<0, I3, 1>{}([&](auto i_res){
// Write num_loop - 3
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop - 3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 2)
{
static_for<0, I2, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
// tail
else if (i == num_loop - 1)
{
static_for<0, I1, 1>{}([&](auto i_res){
// Write num_loop
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, Number<i_res>{});
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, Number<i_res>{});
block_sync_lds();
// GEMM num_loop
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
});
}
} }
}; };
......
...@@ -233,7 +233,7 @@ template <index_t BlockSize, ...@@ -233,7 +233,7 @@ template <index_t BlockSize,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
bool BBlockLdsExtraN1 = false, bool BBlockLdsExtraN1 = false,
index_t NumGemmKPrefetchStage = 3> index_t NumGemmKPrefetchStage = 4>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment