"examples/vscode:/vscode.git/clone" did not exist on "4df9d4921862e8cb12fa87a43af9967077e39566"
Commit c5fd087e authored by aska-0096

Attn, skip b lds

parent 6e28a8ac
@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                     const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
-       return Transform::MakeB0GridDescriptor_BK0_N_BK1(
-           Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
-           Number<K1>{});
        if constexpr(B0EnableLds)
        {
            return Transform::MakeB0GridDescriptor_BK0_N_BK1(
                Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                    b0_gs_ls_ks_strides_vec),
                Number<K1>{});
        }
        else
        {
            return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
                Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                    b0_gs_ls_ks_strides_vec),
                Number<WmmaK>{},
                Number<LRepeat>{},
                Number<LWaves>{},
                Number<LPerWmma>{},
                Number<K1>{});
        }
    }
-   static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
-                                              const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
    static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
                                     const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
    {
-       return Transform::MakeB1GridDescriptor_BK0_N_BK1(
-           Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
-           Number<L1>{});
        if constexpr(B1EnableLds)
        {
            return Transform::MakeB1GridDescriptor_BK0_N_BK1(
                Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                    b1_gs_ns_ls_strides_vec),
                Number<L1>{});
        }
        else
        {
            return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
                Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                    b1_gs_ns_ls_strides_vec),
                Number<WmmaK>{},
                Number<NRepeat>{},
                Number<NWaves>{},
                Number<NPerWmma>{},
                Number<L1>{});
        }
    }
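The two makers above select the descriptor layout at compile time: with LDS enabled, B0/B1 keep the familiar K0 x L x K1 (L0 x N x L1) shape, otherwise they are split into a 6-D per-thread WMMA layout. A minimal sketch of that if-constexpr dispatch in plain C++; Desc3D, Desc6D and make_b0_descriptor are hypothetical stand-ins, not the CK API and not part of this commit:

#include <array>
#include <cstddef>

struct Desc3D { std::array<std::size_t, 3> lengths; }; // BK0 x L x BK1 (staged through LDS)
struct Desc6D { std::array<std::size_t, 6> lengths; }; // KWmma x LRepeat x LWaves x KRow x LPerWmma x BK1 (kept in VGPR)

template <bool B0EnableLds>
constexpr auto make_b0_descriptor(std::size_t L, std::size_t K, std::size_t BK1,
                                  std::size_t WmmaK, std::size_t LWaves, std::size_t LPerWmma)
{
    if constexpr(B0EnableLds)
    {
        return Desc3D{{K / BK1, L, BK1}};
    }
    else
    {
        // K is split into KWmma x KRow x BK1, L into LRepeat x LWaves x LPerWmma
        return Desc6D{{K / WmmaK, L / (LWaves * LPerWmma), LWaves, WmmaK / BK1, LPerWmma, BK1}};
    }
}

static_assert(make_b0_descriptor<true>(128, 64, 8, 16, 2, 16).lengths[0] == 8);  // K0 = K / BK1
static_assert(make_b0_descriptor<false>(128, 64, 8, 16, 2, 16).lengths[3] == 2); // KRow = WmmaK / BK1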
    using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
-   using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
-   using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
    using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));

    using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

    constexpr static auto make_MaskOutPredicate()
    {
@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        InMemoryDataOperationEnum::Set,
        // InMemory Data Descriptor
        AGridDesc,
-       B0GridDesc_BK0_L_BK1,
-       B1GridDesc_BL0_N_BL1,
        B0GridDesc,
        B1GridDesc,
        CGridDesc_M_N,
        // Tiling Family
        MPerBlock,
@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-             b0_grid_desc_bk0_l_bk1_{
-             b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
-                 b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              b0_grid_desc{
                  DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc{
                  DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
                  Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              a_grid_desc_g_m_k_{
@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

-           if(GridwiseOp::CheckValidity(a_grid_desc,
-                                        b0_grid_desc_bk0_l_bk1_,
-                                        b1_grid_desc_bl0_n_bl1_,
-                                        c_grid_desc_m_n_,
-                                        block_2_ctile_map_))
            if(GridwiseOp::CheckValidity(
                   a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        // Tensor Descriptors
        AGridDesc a_grid_desc;
-       B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
-       B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        B0GridDesc b0_grid_desc;
        B1GridDesc b1_grid_desc;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            B1DataType,
            CDataType,
            DeviceOp::AGridDesc,
-           DeviceOp::B0GridDesc_BK0_L_BK1,
-           DeviceOp::B1GridDesc_BL0_N_BL1,
            DeviceOp::B0GridDesc,
            DeviceOp::B1GridDesc,
            typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
            AElementwiseOperation,
            B0ElementwiseOperation,
@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                arg.p_b1_grid_,
                arg.p_c_grid_,
                arg.a_grid_desc,
-               arg.b0_grid_desc_bk0_l_bk1_,
-               arg.b1_grid_desc_bl0_n_bl1_,
                arg.b0_grid_desc,
                arg.b1_grid_desc,
                arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                arg.a_element_op_,
                arg.b0_element_op_,
@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            }
            if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
-                                         arg.b0_grid_desc_bk0_l_bk1_,
-                                         arg.b1_grid_desc_bl0_n_bl1_,
                                          arg.b0_grid_desc,
                                          arg.b1_grid_desc,
                                          arg.c_grid_desc_m_n_,
                                          arg.block_2_ctile_map_))
            {
...
@@ -18,14 +18,14 @@
namespace ck {

-template <typename GridwiseGemm,
-          typename FloatA,
-          typename FloatB0,
-          typename FloatB1,
-          typename FloatC,
template <typename GridwiseOp,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AGridDesc,
-          typename B0GridDesc_BK0_L_BK1,
-          typename B1GridDesc_BL0_N_BL1,
          typename B0GridDesc,
          typename B1GridDesc,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
@@ -41,13 +41,13 @@ __global__ void
        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
-           const FloatA* __restrict__ p_a_grid,
-           const FloatB0* __restrict__ p_b0_grid,
-           const FloatB1* __restrict__ p_b1_grid,
-           FloatC* __restrict__ p_c_grid,
            const ADataType* __restrict__ p_a_grid,
            const B0DataType* __restrict__ p_b0_grid,
            const B1DataType* __restrict__ p_b1_grid,
            CDataType* __restrict__ p_c_grid,
            const AGridDesc a_grid_desc,
-           const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
-           const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
            const B0GridDesc b0_grid_desc,
            const B1GridDesc b1_grid_desc,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
@@ -61,7 +61,7 @@ __global__ void
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-   __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -76,30 +76,30 @@ __global__ void
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

-   GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b0_grid + b0_batch_offset,
                                                p_b1_grid + b1_batch_offset,
                                                p_c_grid + c_batch_offset,
                                                p_shared,
                                                a_grid_desc,
-                                               b0_grid_desc_bk0_l_bk1,
-                                               b1_grid_desc_l0_n_l1,
                                                b0_grid_desc,
                                                b1_grid_desc,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b0_element_op,
                                                acc_element_op,
                                                b1_element_op,
                                                c_element_op,
                                                c0_matrix_mask,
                                                block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b0_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc;
-   ignore = b0_grid_desc_bk0_l_bk1;
-   ignore = b1_grid_desc_l0_n_l1;
    ignore = b0_grid_desc;
    ignore = b1_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b0_element_op;
@@ -115,13 +115,13 @@ __global__ void
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
-template <typename FloatA,
-          typename FloatB0,
-          typename FloatAcc0,
-          typename FloatB1,
-          typename FloatAcc1,
-          typename FloatCShuffle,
-          typename FloatC,
template <typename ADataType,
          typename B0DataType,
          typename Acc0DataType,
          typename B1DataType,
          typename Acc1DataType,
          typename CShuffleDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
@@ -129,8 +129,8 @@ template <typename FloatA,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc,
-          typename B0GridDesc_BK0_L_BK1,
-          typename B1GridDesc_BL0_N_BL1,
          typename B0GridDesc,
          typename B1GridDesc,
          typename CGridDesc_M_N,
          index_t MPerBlock,
          index_t LPerBlock,
@@ -163,7 +163,7 @@ template <typename FloatA,
          index_t B0BlockTransferDstScalarPerVector_K1,
          bool B0ThreadTransferSrcResetCoordinateAfterRun,
          bool B0EnableLds,
-          bool B0BlockLdsExtraN,
          bool B0BlockLdsExtraL,
          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
@@ -204,8 +204,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    static constexpr auto BL1 = Number<L1Value>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);

    static constexpr auto WmmaK = 16;
    static constexpr auto WmmaL = 16;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -250,6 +252,73 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        return a_block_desc;
    }
__host__ __device__ static constexpr auto MakeB0BlockDescriptor()
{
constexpr auto b0_block_desc = [&]() {
if constexpr(B0EnableLds)
{
// K0->L->BK1 Per Block
constexpr auto K0PerBlock = KPerBlock / BK1;
constexpr auto max_lds_align = BK1;
if constexpr(B0BlockLdsExtraL)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1),
make_tuple(Number<LPerBlock + 1>{} * BK1, BK1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1), max_lds_align);
}
}
else
{
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
            // KWmma->LRepeat->LWave->KRow->LPerWmma->BK1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1),
make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1));
}
}();
return b0_block_desc;
}
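The LDS branch above optionally pads the L stride by one element (B0BlockLdsExtraL) so consecutive K0 rows do not all start at the same LDS bank offset. A small sketch of the offset arithmetic with made-up sizes; not part of this commit and not the CK descriptor API:

#include <cstddef>

// offset of element (k0, l, k1) for the two LDS layouts built above
constexpr std::size_t offset_padded(std::size_t k0, std::size_t l, std::size_t k1,
                                    std::size_t LPerBlock, std::size_t BK1)
{
    return k0 * (LPerBlock + 1) * BK1 + l * BK1 + k1; // B0BlockLdsExtraL == true
}
constexpr std::size_t offset_plain(std::size_t k0, std::size_t l, std::size_t k1,
                                   std::size_t LPerBlock, std::size_t BK1)
{
    return k0 * LPerBlock * BK1 + l * BK1 + k1;       // aligned layout, no padding
}

// without padding every K0 row starts at a multiple of LPerBlock*BK1 = 512 elements;
// the extra BK1 elements per row stagger the row starts across banks
static_assert(offset_plain(1, 0, 0, 64, 8) == 512);
static_assert(offset_padded(1, 0, 0, 64, 8) == 520);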
__host__ __device__ static constexpr auto MakeB1BlockDescriptor()
{
constexpr auto b1_block_desc = [&]() {
if constexpr(B1EnableLds)
{
// L0->N->BL1 Per Block
constexpr auto max_lds_align = BL1;
if constexpr(B1BlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1),
make_tuple(Number<NPerBlock + 1>{} * BL1, BL1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1), max_lds_align);
}
}
else
{
constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
            // LWmma->NRepeat->NWave->LRow->NPerWmma->BL1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1),
make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1));
}
}();
return b1_block_desc;
}
    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
    {
        constexpr auto a_block_copy_step = [&]() {
@@ -270,6 +339,44 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        return a_block_copy_step;
    }
__host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep()
{
constexpr auto b0_block_copy_step = [&]() {
if constexpr(B0EnableLds)
{
constexpr auto K0PerBlock = KPerBlock / BK1;
return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return b0_block_copy_step;
}
__host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep()
{
constexpr auto b1_block_copy_step = [&]() {
if constexpr(B1EnableLds)
{
return make_multi_index(L0PerBlock, 0, 0);
}
else
{
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0);
}
}();
return b1_block_copy_step;
}
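Each slice-copy step advances the source window by one K tile. A plain-C++ sketch of the LDS-path window arithmetic with hypothetical sizes; not part of this commit and not the CK API:

#include <array>
#include <cstddef>

int main()
{
    constexpr std::size_t KPerBlock = 32, BK1 = 8, K = 128;
    constexpr std::size_t K0PerBlock = KPerBlock / BK1;  // the step returned for the LDS path
    std::array<std::size_t, 3> window_origin{0, 0, 0};   // (k0, l, k1) of the current source tile

    // each outer-K iteration slides the window by one tile, like MoveSrcSliceWindow
    for(std::size_t iter = 0; iter < K / KPerBlock; ++iter)
        window_origin[0] += K0PerBlock;

    return window_origin[0] == K / BK1 ? 0 : 1; // ends one full K extent further along K0
}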
    // Describe how data read from (LDS/VGPR) buffer
    template <typename ABlockDesc_>
    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
@@ -323,26 +430,61 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        return a_wave_desc;
    }
-   template <typename B0BlockDesc_BK0_L_BK1>
-   __host__ __device__ static constexpr auto
-   MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
    template <typename B0BlockDesc_>
    __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
    {
-       constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
-       constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
-       constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
-       return transform_tensor_descriptor(
-           B0BlockDesc_BK0_L_BK1{},
-           make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                      make_unmerge_transform(
-                          make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
-                      make_pass_through_transform(Number<B_K1>{})),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-           make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
        constexpr auto b0_wave_desc = [&]() {
            if constexpr(B0EnableLds)
            {
                // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
                constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
                constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
                return transform_tensor_descriptor(
                    B0BlockDesc_{},
                    make_tuple(make_pass_through_transform(Number<B_K0>{}),
                               make_unmerge_transform(make_tuple(
                                   Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                               make_pass_through_transform(Number<B_K1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
            }
            else
            {
                // KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
                constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
                constexpr auto B_K1  = B0BlockDesc_{}.GetLength(I5);
                // Workaround, Freeze transform
                return transform_tensor_descriptor(
                    B0BlockDesc_{},
                    make_tuple(make_freeze_transform(I0),
                               make_pass_through_transform(Number<KWmma>{}),
                               make_pass_through_transform(Number<LRepeat>{}),
                               make_pass_through_transform(I1),
                               make_pass_through_transform(I1),
                               make_pass_through_transform(Number<B_K1>{})),
                    make_tuple(Sequence<3>{},
                               Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<4>{},
                               Sequence<5>{}),
                    make_tuple(Sequence<>{},
                               Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{}));
            }
        }();

        return b0_wave_desc;
    }
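The freeze transform above drops the KRow dimension from the 6-D per-thread layout so the wave-level GEMM sees a 5-D view. A plain-C++ sketch of the same index math under assumed strides; not part of this commit and not the CK transform API:

#include <array>
#include <cstddef>

using Idx6 = std::array<std::size_t, 6>;

// naive strides of the 6-D register-side layout built in MakeB0BlockDescriptor (LDS-free path):
// lengths {KWmmaPerBlock, LRepeat, 1, 1, 1, BK1}, strides {LRepeat*BK1, BK1, BK1, BK1, BK1, 1}
constexpr std::size_t offset6(const Idx6& i, const Idx6& s)
{
    std::size_t off = 0;
    for(std::size_t d = 0; d < 6; ++d)
        off += i[d] * s[d];
    return off;
}

// the 5-D wave view: (kwmma, lrepeat, lwave, lperwmma, k1) with KRow pinned ("frozen") to 0
constexpr std::size_t offset5_frozen(std::size_t kwmma, std::size_t lrepeat, std::size_t lwave,
                                     std::size_t lperwmma, std::size_t k1, const Idx6& s)
{
    return offset6({kwmma, lrepeat, lwave, 0, lperwmma, k1}, s);
}

constexpr Idx6 stride{16, 8, 8, 8, 8, 1}; // LRepeat = 2, BK1 = 8
static_assert(offset5_frozen(1, 1, 0, 0, 3, stride) == 16 + 8 + 3);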
    template <typename A1BlockDesc_AL0_M_AL1>
    __host__ __device__ static constexpr auto
-   MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
    MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
    {
        constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
        constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
@@ -356,37 +498,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
    }
-   template <typename B1BlockDesc_BL0_N_BL1>
-   __host__ __device__ static constexpr auto
-   MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
-   {
-       constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
-       constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
-       return transform_tensor_descriptor(
-           B1BlockDesc_BL0_N_BL1{},
-           make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                      make_unmerge_transform(
-                          make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
-                      make_pass_through_transform(Number<B_K1>{})),
-           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-           make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-   }
-
-   __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
-   {
-       // B matrix in LDS memory, dst of blockwise copy
-       return make_naive_tensor_descriptor(
-           make_tuple(BK0, Number<LPerBlock>{}, BK1),
-           make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
-   }
-
-   __host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
-   {
-       // B1 matrix in LDS memory, dst of blockwise copy
-       return make_naive_tensor_descriptor(
-           make_tuple(BL0, Number<NPerBlock>{}, BL1),
-           make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
-   }

    template <typename B1BlockDesc_>
    __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&)
    {
        constexpr auto b1_wave_desc = [&]() {
            if constexpr(B1EnableLds)
            {
                // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
                constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
                constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
                return transform_tensor_descriptor(
                    B1BlockDesc_{},
                    make_tuple(make_pass_through_transform(Number<B_L0>{}),
                               make_unmerge_transform(make_tuple(
                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                               make_pass_through_transform(Number<B_L1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
            }
            else
            {
                // LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1
                constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
                constexpr auto B_L1  = B1BlockDesc_{}.GetLength(I5);
                // Workaround, Freeze transform
                return transform_tensor_descriptor(
                    B1BlockDesc_{},
                    make_tuple(make_freeze_transform(I0),
                               make_pass_through_transform(Number<LWmma>{}),
                               make_pass_through_transform(Number<NRepeat>{}),
                               make_pass_through_transform(I1),
                               make_pass_through_transform(I1),
                               make_pass_through_transform(Number<B_L1>{})),
                    make_tuple(Sequence<3>{},
                               Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<4>{},
                               Sequence<5>{}),
                    make_tuple(Sequence<>{},
                               Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{}));
            }
        }();

        return b1_wave_desc;
    }
    __host__ __device__ static constexpr auto
@@ -410,31 +571,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    {
        // LDS allocation for A and B: be careful of alignment
        const index_t gemm0_bytes_end =
-           (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
-            SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
            (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
             SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType));
        const index_t gemm1_bytes_end =
            (SharedMemTrait::b1_block_space_offset +
-            SharedMemTrait::b1_block_space_size_aligned * sizeof(FloatB1));
             SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType));
        const index_t softmax_bytes_end =
            SharedMemTrait::reduction_space_offset +
-           SharedMemTrait::reduction_space_size_aligned * sizeof(FloatAcc0);
            SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType);
        const index_t c_block_bytes_end =
-           SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
            SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType);

        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
-   __host__ __device__ static constexpr bool
-   CheckValidity(const AGridDesc& a_grid_desc,
-                 const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
-                 const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
-                 const CGridDesc_M_N& c_grid_desc_m_n,
-                 const Block2CTileMap& block_2_ctile_map)
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                                            const B0GridDesc& b0_grid_desc,
                                                            const B1GridDesc& b1_grid_desc,
                                                            const CGridDesc_M_N& c_grid_desc_m_n,
                                                            const Block2CTileMap& block_2_ctile_map)
    {
        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
                          (LPerBlock % (LPerWmma * LRepeat)) == 0,
@@ -455,10 +615,40 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            }
        };
const auto GetB0ProblemsizeLK = [&]() {
if constexpr(B0EnableLds)
{
return make_tuple(b0_grid_desc.GetLength(I1),
b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
b0_grid_desc.GetLength(I4),
b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
b0_grid_desc.GetLength(I5));
}
};
const auto GetB1ProblemsizeNL = [&]() {
if constexpr(B1EnableLds)
{
return make_tuple(b1_grid_desc.GetLength(I1),
b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
b1_grid_desc.GetLength(I4),
b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
b1_grid_desc.GetLength(I5));
}
};
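For the LDS-free path the L and K problem sizes have to be re-assembled from the 6-D descriptor lengths; the grouping used by GetB0ProblemsizeLK looks like this in plain C++ (hypothetical lengths, not part of this commit and not the CK API):

#include <array>
#include <cstddef>

// lengths of the 6-D B0 grid descriptor: {KWmma, L0*LRepeat, LWaves, KRow, LPerWmma, BK1}
constexpr std::array<std::size_t, 6> b0_lengths{8, 4, 2, 2, 16, 8};

// L lives in dims 1, 2, 4 and K in dims 0, 3, 5, the same grouping GetB0ProblemsizeLK uses
constexpr std::size_t L = b0_lengths[1] * b0_lengths[2] * b0_lengths[4]; // 4 * 2 * 16 = 128
constexpr std::size_t K = b0_lengths[0] * b0_lengths[3] * b0_lengths[5]; // 8 * 2 * 8  = 128

static_assert(L == 128 && K == 128);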
        const auto M = GetAProblemsizeMK()[I0];
-       const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
        const auto L = GetB0ProblemsizeLK()(I0);
        const auto K = GetAProblemsizeMK()[I1];
-       const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
        const auto N = GetB1ProblemsizeNL()(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
        {
@@ -567,17 +757,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                              max_lds_align)
                        : 0;
        static constexpr auto b0_block_space_size_aligned =
-           B0EnableLds
-               ? math::integer_least_multiple(
-                     GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize(),
-                     max_lds_align)
-               : 0;
            B0EnableLds ? math::integer_least_multiple(
                              MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align)
                        : 0;
        static constexpr auto b1_block_space_size_aligned =
-           B1EnableLds
-               ? math::integer_least_multiple(
-                     GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize(),
-                     max_lds_align)
-               : 0;
            B1EnableLds ? math::integer_least_multiple(
                              MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align)
                        : 0;

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b0_block_space_offset = a_block_space_size_aligned;
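GetSharedMemoryNumberOfByte sizes LDS as the maximum of the four stage footprints, and with B0EnableLds/B1EnableLds set to false the b0/b1 terms above drop to zero, which is the point of this commit. A rough illustrative budget with made-up element counts (assumed fp16 inputs and fp32 accumulators, not the kernel's real numbers):

#include <algorithm>
#include <cstddef>

// made-up element counts standing in for the aligned space sizes computed above
constexpr std::size_t a_elems = 4096, b0_elems = 4096, b1_elems = 4096;
constexpr std::size_t reduction_elems = 256, c_shuffle_elems = 8192;

constexpr std::size_t gemm0_bytes   = (a_elems + b0_elems) * 2;          // fp16 A and B0
constexpr std::size_t gemm1_bytes   = (a_elems + b1_elems) * 2;          // b1 offset + fp16 B1
constexpr std::size_t softmax_bytes = a_elems * 2 + reduction_elems * 4; // offset + fp32 workspace
constexpr std::size_t c_bytes       = c_shuffle_elems * 4;               // fp32 C shuffle tile

// the stages run one after another, so only the largest footprint has to fit
constexpr std::size_t lds_bytes = std::max({gemm0_bytes, gemm1_bytes, softmax_bytes, c_bytes});
static_assert(lds_bytes <= 64 * 1024, "fits a 64 KiB workgroup LDS");
// with B0EnableLds / B1EnableLds == false, b0_elems and b1_elems would be 0 here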
@@ -599,14 +785,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    template <bool HasMainKBlockLoop,
              typename C0MatrixMask,
              typename Block2CTileMap = DefaultBlock2CTileMap>
-   __device__ static void Run(const FloatA* __restrict__ p_a_grid,
-                              const FloatB0* __restrict__ p_b0_grid,
-                              const FloatB1* __restrict__ p_b1_grid,
-                              FloatC* __restrict__ p_c_grid,
    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                               const B0DataType* __restrict__ p_b0_grid,
                               const B1DataType* __restrict__ p_b1_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc& a_grid_desc,
-                              const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
-                              const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
                               const B0GridDesc& b0_grid_desc,
                               const B1GridDesc& b1_grid_desc,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
@@ -623,9 +809,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc.GetElementSpaceSize());
        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
            p_b0_grid, b0_grid_desc.GetElementSpaceSize());
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
            p_b1_grid, b1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -648,17 +834,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
-       const auto K = [&](){
-           if constexpr(AEnableLds){
-               return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
-           }
-           else{
-               return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
-           }
-       }();
        constexpr auto a_block_desc = MakeABlockDescriptor();
-       constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
        constexpr auto b0_block_desc = MakeB0BlockDescriptor();

        auto a_block_trait = [&](){
            // A matrix blockwise copy
@@ -666,7 +843,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            {
                constexpr auto AK0PerBlock = KPerBlock/ AK1;
                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                   static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
                    static_cast<ADataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                    SharedMemTrait::a_block_space_size_aligned);

                auto a_blockwise_copy =
@@ -677,8 +854,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename BlockSliceLengths, */ Sequence<AK0PerBlock, MPerBlock, AK1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-/* typename SrcData, */ FloatA,
-/* typename DstData, */ FloatA,
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
@@ -705,13 +882,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                // Thread-wise copy
                // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-               auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                    a_block_desc.GetElementSpaceSize());

                // Limitation: NumDim of Src and Dst descriptor should be identical
                auto a_blockwise_copy =
-                   ThreadwiseTensorSliceTransfer_v2<FloatA,
-                                                    FloatA,
                    ThreadwiseTensorSliceTransfer_v2<ADataType,
                                                     ADataType,
                                                     decltype(a_grid_desc),
                                                     decltype(a_block_desc),
                                                     Sequence<Number<KWmmaPerBlock>{},
@@ -736,20 +913,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                return make_tuple(a_block_buf, a_blockwise_copy);
            }
        };
auto b0_block_trait = [&](){
if constexpr(B0EnableLds)
{
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<B0DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
SharedMemTrait::b0_block_space_size_aligned);
-               // B matrix blockwise copy
                auto b0_blockwise_copy =
                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                B0ElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<BK0, LPerBlock, BK1>,
                                                B0BlockTransferThreadClusterLengths_K0_L_K1,
                                                B0BlockTransferThreadClusterArrangeOrder,
-                                               FloatB0,
-                                               FloatB0,
-                                               decltype(b0_grid_desc_k0_l_k1),
-                                               decltype(b0_block_desc_k0perblock_lperblock_k1),
                                                B0DataType,
                                                B0DataType,
                                                decltype(b0_grid_desc),
                                                decltype(b0_block_desc),
                                                B0BlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                B0BlockTransferSrcVectorDim,
@@ -760,15 +943,57 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                1,
                                                B0ThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
-               b0_grid_desc_k0_l_k1,
                b0_grid_desc,
                make_multi_index(0, 0, 0),
                b0_element_op,
-               b0_block_desc_k0perblock_lperblock_k1,
                b0_block_desc,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(b0_block_buf, b0_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
b0_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b0_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<B0DataType,
B0DataType,
decltype(b0_grid_desc),
decltype(b0_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<LRepeat>{},
I1,
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B0BlockTransferSrcScalarPerVector,
B0ThreadTransferSrcResetCoordinateAfterRun,
true>(
b0_grid_desc,
make_multi_index(0,
0/(LWaves * LPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b0_block_buf, b0_blockwise_copy);
}
};
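b0_block_trait (and b1_block_trait below) return a (buffer, copier) pair whose concrete types differ between the LDS path and the register path; only the selected branch is ever instantiated. A minimal C++17 sketch of the pattern with hypothetical stand-in types, not the CK classes and not part of this commit:

#include <tuple>

struct LdsBuffer  {}; // stand-in for the LDS-backed block buffer
struct VgprBuffer {}; // stand-in for the register-resident block buffer
struct GroupCopy  {}; // stand-in for ThreadGroupTensorSliceTransfer_v4r1
struct ThreadCopy {}; // stand-in for ThreadwiseTensorSliceTransfer_v2

template <bool EnableLds>
auto make_block_trait()
{
    // the two branches return pairs of different concrete types; if constexpr
    // guarantees only the selected branch is instantiated for a given kernel
    if constexpr(EnableLds)
        return std::make_tuple(LdsBuffer{}, GroupCopy{});
    else
        return std::make_tuple(VgprBuffer{}, ThreadCopy{});
}

int main()
{
    auto trait = make_block_trait</*EnableLds=*/false>();
    [[maybe_unused]] auto buf  = std::get<0>(trait); // VgprBuffer
    [[maybe_unused]] auto copy = std::get<1>(trait); // ThreadCopy
    return 0;
}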
        auto a_block_buf      = a_block_trait()[I0];
        auto a_blockwise_copy = a_block_trait()[I1];
auto b0_block_buf = b0_block_trait()[I0];
auto b0_blockwise_copy = b0_block_trait()[I1];
        /*******************************************************************************/
        // Gemm0
@@ -776,11 +1001,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        auto blockwise_gemm0 = BlockwiseGemmWMMA<
            BlockSize,
-           FloatA,
-           FloatB0,
-           FloatAcc0,
            ADataType,
            B0DataType,
            Acc0DataType,
            decltype(MakeAWaveDescriptor(a_block_desc)),
-           decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
            decltype(MakeB0WaveDescriptor(b0_block_desc)),
            MPerBlock,
            LPerBlock,
            KPerBlock,
@@ -816,16 +1041,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        /*******************************************************************************/
-       // LDS allocation for A and B: be careful of alignment
-       auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
-           SharedMemTrait::b0_block_space_size_aligned);

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
-       constexpr auto b0_block_slice_copy_step = make_multi_index(BK0, 0, 0);
        constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep();

        const auto a_block_reset_copy_step = [&](){
            if constexpr(AEnableLds){
@@ -836,14 +1055,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            }
        }();

-       const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
        const auto b0_block_reset_copy_step = [&](){
            if constexpr(B0EnableLds){
                return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
            }
            else{
                return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0);
            }
        }();

        const auto K = [&](){
            if constexpr(AEnableLds){
                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
            }
            else{
                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
            }
        }();

        const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        /*******************************************************************************/
        // softmax
        /*******************************************************************************/
        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
            static_cast<Acc0DataType*>(p_shared) + SharedMemTrait::reduction_space_offset,
            SharedMemTrait::reduction_space_size_aligned);

        // get acc0 7D thread cluster
        constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
@@ -879,7 +1114,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));

        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
-                                                 FloatAcc0,
                                                  Acc0DataType,
                                                  decltype(threadid_to_l_n_thread_cluster_adaptor),
                                                  decltype(thread_cluster_desc_m_l),
                                                  decltype(thread_slice_desc_m_l)>{};
@@ -889,15 +1124,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
        running_sum     = 0;
        running_sum_new = 0;
-       running_max     = NumericLimits<FloatAcc0>::Lowest();
-       running_max_new = NumericLimits<FloatAcc0>::Lowest();
        running_max     = NumericLimits<Acc0DataType>::Lowest();
        running_max_new = NumericLimits<Acc0DataType>::Lowest();

        /*******************************************************************************/
        // set up Gemm1
        /*******************************************************************************/
-       // B1 matrix in LDS memory, dst of blockwise copy
-       constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
-       constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

        // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
        // A1 matrix in VGPR
        constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
@@ -915,8 +1146,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // A1 matrix blockwise copy
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
-           FloatAcc0,
-           FloatA,
            Acc0DataType,
            ADataType,
            decltype(acc0_thread_desc_l0perblock_mperblock_l1),
            decltype(a1_thread_desc_l0perblock_mperblock_l1),
            tensor_operation::element_wise::PassThrough,
@@ -925,8 +1156,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            2,
            laccvgprs>{tensor_operation::element_wise::PassThrough{}};

-       // B1 matrix blockwise copy
        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
            a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());

        constexpr auto b1_block_desc = MakeB1BlockDescriptor();

        auto b1_block_trait = [&](){
            if constexpr(B1EnableLds)
            {
                auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
                    SharedMemTrait::b1_block_space_size_aligned);

                auto b1_blockwise_copy =
                    ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
@@ -934,10 +1176,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
-/* typename SrcData, */ FloatB1,
-/* typename DstData, */ FloatB1,
-/* typename SrcDesc, */ decltype(b1_grid_desc_l0_n_l1),
-/* typename DstDesc, */ decltype(b1_block_desc_l0perblock_nperblock_l1),
/* typename SrcData, */ B1DataType,
/* typename DstData, */ B1DataType,
/* typename SrcDesc, */ decltype(b1_grid_desc),
/* typename DstDesc, */ decltype(b1_block_desc),
/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
@@ -949,26 +1191,64 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
                    NumGemmKPrefetchStage>(
-                   b1_grid_desc_l0_n_l1,
                    b1_grid_desc,
                    make_multi_index(0, n_block_data_idx_on_grid, 0),
                    b1_element_op,
-                   b1_block_desc_l0perblock_nperblock_l1,
                    b1_block_desc,
                    make_multi_index(0, 0, 0),
                    tensor_operation::element_wise::PassThrough{});

-       auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
-           a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
-       auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatB1*>(p_shared)+ SharedMemTrait::b1_block_space_offset,
-           SharedMemTrait::b1_block_space_size_aligned);

                return make_tuple(b1_block_buf, b1_blockwise_copy);
            }
            else
            {
                // Thread-wise copy
                // LTilePerBlock/WmmaL -> NRepeat -> NWaves -> WmmaL/L1 -> NPerWmma -> L1
constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>(
b1_block_desc.GetElementSpaceSize());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto b1_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<B1DataType,
B1DataType,
decltype(b1_grid_desc),
decltype(b1_block_desc),
Sequence<Number<LWmmaPerBlock>{},
Number<NRepeat>{},
I1,
I1,
I1,
Number<L1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B1BlockTransferSrcScalarPerVector,
B1ThreadTransferSrcResetCoordinateAfterRun,
true>(
b1_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));
return make_tuple(b1_block_buf, b1_blockwise_copy);
}
};
auto b1_block_buf = b1_block_trait()[I0];
auto b1_blockwise_copy = b1_block_trait()[I1];
constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep();
        auto blockwise_gemm1 =
            BlockwiseGemmWMMA<BlockSize,
-                             FloatA,
-                             FloatB1,
-                             FloatAcc1,
-                             decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
-                             decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
                              ADataType,
                              B1DataType,
                              Acc1DataType,
                              decltype(MakeA1WaveDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
                              decltype(MakeB1WaveDescriptor(b1_block_desc)),
                              MPerBlock,
                              NPerBlock,
                              LTilePerBlock,
@@ -983,11 +1263,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

-       const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
        const auto L = [&](){
            if constexpr(B0EnableLds){
                return b0_grid_desc.GetLength(I1);
            }
            else{
                return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I4);
            }
        }();

        const index_t num_gemm1_l_block_outer_loop = L / LPerBlock;
        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;

        // Initialize C
-       StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, Acc1DataType, acc1_thread_buf.Size(), true> c_thread_buf;
        c_thread_buf.Clear();
/*******************************************************************************/ /*******************************************************************************/
...@@ -1014,8 +1303,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -1014,8 +1303,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b0_grid_desc_k0_l_k1, b0_grid_desc,
b0_block_desc_k0perblock_lperblock_k1, b0_block_desc,
b0_blockwise_copy, b0_blockwise_copy,
b0_grid_buf, b0_grid_buf,
b0_block_buf, b0_block_buf,
...@@ -1106,20 +1395,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -1106,20 +1395,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
acc1_thread_buf.Clear(); acc1_thread_buf.Clear();
// preload data into LDS // preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf); b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1, b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
b1_block_slice_copy_step); b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf); b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
// main body // main body
if constexpr(num_gemm1_l_block_inner_loop > 1) if constexpr(num_gemm1_l_block_inner_loop > 1)
{ {
static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
// Data cast from FloatAcc0 to FloatA happen here // Data cast from Acc0DataType to ADataType happen here
a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0), make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc0_thread_buf, acc0_thread_buf,
...@@ -1127,7 +1416,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -1127,7 +1416,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
a1_thread_buf); a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf); b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
block_sync_lds(); block_sync_lds();
...@@ -1135,10 +1424,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -1135,10 +1424,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds(); block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1, b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
b1_block_slice_copy_step); b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf); b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
}); });
} }
// tail // tail
...@@ -1177,9 +1466,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle ...@@ -1177,9 +1466,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{}; auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
FloatAcc1 c = c_thread_buf[I]; // O Acc1DataType c = c_thread_buf[I]; // O
FloatAcc1 c_new = Acc1DataType c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) / math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; running_sum_new[iM];
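The c_new update above is the standard online-softmax (flash-attention style) rescaling: the previously accumulated output is rescaled to the new running max before the new L-tile's contribution is added. A scalar sketch of the same arithmetic in plain C++ with made-up values, not part of this commit:

#include <cassert>
#include <cmath>

int main()
{
    // running state for one output row of O = softmax(Q*K^T) * V
    float running_max = -1e30f, running_sum = 0.0f, c = 0.0f;

    // one L-tile arrives with its local max, local exp-sum and partial product acc1 = P*V
    float tile_max = 2.0f, tile_sum = 3.5f, acc1 = 1.25f;

    float running_max_new = std::fmax(running_max, tile_max);
    float running_sum_new = running_sum * std::exp(running_max - running_max_new) +
                            tile_sum * std::exp(tile_max - running_max_new);

    // same correction as c_new above: rescale the old output, add the rescaled new tile
    float c_new = (running_sum * std::exp(running_max - running_max_new) * c +
                   std::exp(tile_max - running_max_new) * acc1) /
                  running_sum_new;

    assert(std::isfinite(c_new));
    return 0;
}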
@@ -1190,7 +1479,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
                                                a_block_reset_copy_step); // rewind K
-           b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc,
                                                 b0_block_reset_copy_step); // rewind K and step N

            // update before next j iteration
@@ -1220,7 +1509,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatCShuffle*>(p_shared),
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

        constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
@@ -1268,8 +1557,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // shuffle: threadwise copy C from VGPR to LDS
        auto c_thread_copy_vgpr_to_lds =
-           ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
-                                              FloatCShuffle,
            ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
                                               CShuffleDataType,
                                               decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                                               decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                                               ck::tensor_operation::element_wise::PassThrough,
@@ -1307,8 +1596,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-               FloatCShuffle,        // typename SrcData,
-               FloatC,               // typename DstData,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...
@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                // Thread-wise copy
                // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-               auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                    b_block_desc.GetElementSpaceSize());

                // Limitation: NumDim of Src and Dst descriptor should be identical
...
@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = WmmaK{} / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
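The transform above unmerges K into (BKWmma, BKRow, BK1) and L into (L0*LRepeat, LWaves, LPerWmma). A small index-arithmetic check of that split with made-up sizes in plain C++; not part of this commit and not the CK transform API:

#include <cstddef>

int main()
{
    constexpr std::size_t K = 64, L = 128;                 // made-up problem sizes
    constexpr std::size_t WmmaK = 16, BK1 = 8;
    constexpr std::size_t NPerBlock = 64, LRepeat = 2, LWaves = 2, LPerWmma = 16;

    constexpr std::size_t BKWmma = K / WmmaK;              // 4
    constexpr std::size_t BKRow  = WmmaK / BK1;            // 2
    constexpr std::size_t L0     = L / NPerBlock;          // 2
    constexpr std::size_t LBlockRepeat = L0 * LRepeat;     // 4

    // lengths of the 6-D descriptor, in the order named by the function
    constexpr std::size_t lengths[6] = {BKWmma, LBlockRepeat, LWaves, BKRow, LPerWmma, BK1};

    static_assert(lengths[0] * lengths[3] * lengths[5] == K, "K is exactly re-assembled");
    static_assert(lengths[1] * lengths[2] * lengths[4] == L, "L is exactly re-assembled");
    return 0;
}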
    //
    // B1
    //
@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = WmmaL{} / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
    //
    // C
    //
...