Commit 5df713ef authored by aska-0096

save progress

parent a6b2f1c1
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/* /*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
|-----------------| |-----------------|
Gemm0 Gemm0
|-------------------------------------| |-------------------------------------|
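For orientation, the contraction described above can be written as a naive single-batch host reference. The following sketch is illustrative only (plain row-major float tensors, one batch, max-subtracted softmax); it is not part of the example file and the function name is made up:

#include <algorithm>
#include <cmath>
#include <vector>

// Naive per-batch reference of C = Softmax(A * B0) * B1 with
// A: M x K, B0: K x L, B1: L x N, C: M x N (all row-major, float).
void reference_gemm_softmax_gemm(const std::vector<float>& a,
                                 const std::vector<float>& b0,
                                 const std::vector<float>& b1,
                                 std::vector<float>& c,
                                 int M, int K, int L, int N)
{
    std::vector<float> acc0(static_cast<std::size_t>(M) * L); // Gemm0 output

    for(int m = 0; m < M; ++m)
        for(int l = 0; l < L; ++l)
        {
            float sum = 0.f;
            for(int k = 0; k < K; ++k)
                sum += a[m * K + k] * b0[k * L + l];
            acc0[m * L + l] = sum;
        }

    // Row-wise softmax over the L dimension (max subtracted for numerical stability).
    for(int m = 0; m < M; ++m)
    {
        float row_max = acc0[m * L];
        for(int l = 1; l < L; ++l)
            row_max = std::max(row_max, acc0[m * L + l]);
        float denom = 0.f;
        for(int l = 0; l < L; ++l)
        {
            acc0[m * L + l] = std::exp(acc0[m * L + l] - row_max);
            denom += acc0[m * L + l];
        }
        for(int l = 0; l < L; ++l)
            acc0[m * L + l] /= denom;
    }

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float sum = 0.f;
            for(int l = 0; l < L; ++l)
                sum += acc0[m * L + l] * b1[l * N + n];
            c[m * N + n] = sum;
        }
}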
@@ -39,7 +39,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16; using ADataType = F16;
using B0DataType = F16; using B0DataType = F16;
using B1DataType = F16; using B1DataType = F16;
using AccDataType = F32; using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = F16; using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
@@ -67,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
@@ -76,11 +77,12 @@ using DeviceGemmInstance =
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc0DataType,
Acc1BiasDataType, Acc1BiasDataType,
AccDataType, Acc1DataType,
CShuffleDataType, CShuffleDataType,
CDataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
Acc0ElementOp, Acc0ElementOp,
@@ -91,21 +93,21 @@ using DeviceGemmInstance =
TensorSpecB0, TensorSpecB0,
TensorSpecB1, TensorSpecB1,
TensorSpecC, TensorSpecC,
1,
256, 256,
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // LPerBlock
32, // KPerBlock 4, // K0PerBlock
64, // Gemm1NPerBlock 8, // K1
32, // Gemm1KPerBlock 64, // NPerBlock
8, // AK1 4, // L0PerBlock
8, // BK1 8, // L1
2, // B1K1 16, // MPerWMMA
32, // MPerXDL 16, // LPerWMMA
32, // NPerXDL 16, // NPerWMMA
1, // MXdlPerWave //Per repeat = wave_m = wave_num, wave_n = 1
4, // NXdlPerWave 1, // MRepeat
2, // Gemm1NXdlPerWave 8, // LRepeat
4, // NRepeat
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
@@ -113,44 +115,44 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<4, 64, 1>, // BBlockTransfer S<4, 64, 1>, // B0BlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<4, 64, 1>, // B1BlockTransfer
S<0, 2, 1>, S<1, 0, 2>,
S<0, 2, 1>, S<1, 0, 2>,
1, 1,
4, 8,
2, 8,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
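The tile parameters chosen above have to compose: K0PerBlock * K1 gives the K extent consumed per block iteration, and each of the M/L/N block extents must equal repeats x waves x per-WMMA size. A standalone plausibility check with the values above — the wave size and the wave-layout equalities are my assumptions about the WMMA path, not taken from this diff:

// Illustrative consistency check for the tile sizes selected above.
constexpr int BlockSize = 256, WaveSize = 32;            // WaveSize assumed for RDNA3
constexpr int MPerBlock = 128, LPerBlock = 128, NPerBlock = 64;
constexpr int K0PerBlock = 4, K1 = 8;
constexpr int MPerWMMA = 16, LPerWMMA = 16, NPerWMMA = 16;
constexpr int MRepeat = 1, LRepeat = 8, NRepeat = 4;

constexpr int KPerBlock = K0PerBlock * K1;               // 32 K elements per block iteration
constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 8
constexpr int LWaves = LPerBlock / (LRepeat * LPerWMMA); // 1
constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // 1

static_assert(MPerBlock % (MRepeat * MPerWMMA) == 0 &&
              LPerBlock % (LRepeat * LPerWMMA) == 0 &&
              NPerBlock % (NRepeat * NPerWMMA) == 0,
              "repeats * waves * per-WMMA size must tile the block evenly");
static_assert(BlockSize / WaveSize == MWaves * LWaves, "assumed: Gemm0 wave layout fills the block");
static_assert(BlockSize / WaveSize == MWaves * NWaves, "assumed: Gemm1 wave layout fills the block");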
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType, B0DataType,
AccDataType, Acc0DataType,
AccDataType, Acc1DataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
Acc0ElementOp>; Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out // Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out // Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType, B1DataType,
CDataType, CDataType,
AccDataType, Acc1DataType,
AElementOp, AElementOp,
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
@@ -198,7 +198,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a_g_m_k({BatchCount, M, K}); Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N}); Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O}); Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
@@ -129,11 +129,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
using Tuple5 = decltype(CalculateAThreadOriginDataIndex()); // using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle( // __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
Tuple4 a_origin = CalculateAThreadOriginDataIndex(), // Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex()) // Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin) // : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(), BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -303,8 +304,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
B_K1, B_K1,
B_K1>; B_K1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
// AThreadCopy a_thread_copy_;
// BThreadCopy b_thread_copy_;
}; };
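The constructor change above replaces constructor-injected thread-copy origins with default member initializers, so the now parameterless constructor keeps only its compile-time checks while the origins are still computed once per object. The pattern in isolation, with placeholder types rather than the CK ones:

struct BlockwiseGemmLike
{
    // Stand-ins for the real origin calculations and thread-copy objects.
    static int CalculateAThreadOriginDataIndex() { return /* thread-dependent */ 0; }
    static int CalculateBThreadOriginDataIndex() { return /* thread-dependent */ 1; }

    struct ThreadCopy
    {
        explicit ThreadCopy(int origin) : origin_(origin) {}
        int origin_;
    };

    // Parameterless constructor: compile-time checks only.
    BlockwiseGemmLike() { static_assert(sizeof(int) == 4, "example check"); }

    // Default member initializers replace the old constructor arguments.
    ThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    ThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};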
// block wise level pipe designed for inline asm // block wise level pipe designed for inline asm
@@ -425,6 +428,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(Number<m0>{},
blk_idx[I0],
waveId_m,
Number<n0>{},
waveId_n,
blk_idx[I1],
blk_idx[I2]);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
@@ -438,6 +460,30 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
NPerBlock % (NPerWMMA * NRepeat) == 0, NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!"); "wrong!");
} }
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
// constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
// constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
I1,
Number<NRepeat>{},
I1,
I1,
NAccVgprs));
}
// Thread level, register descriptor. Vector-write // Thread level, register descriptor. Vector-write
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
@@ -483,6 +529,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
} }
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Provide dimension size // Provide dimension size
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
@@ -22,186 +22,97 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, // Computes C = A * B0 * B1
typename FloatAB, // MN = MK * KL * LN
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
c0_matrix_mask);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
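Both the removed XDL kernel above and its WMMA replacement split the flat grid evenly across batches and shift each pointer by a per-batch base offset obtained from the G_M_K/G_L_K/G_N_L/G_M_N descriptors. A host-side sketch of that index arithmetic, simplified to dense per-batch strides (names are illustrative):

#include <cstdint>

struct BatchStrides
{
    std::int64_t a_batch_stride, b0_batch_stride, b1_batch_stride, c_batch_stride;
};

// Mirrors the device-side logic: which batch a block works on, and the
// element offsets added to each base pointer for that batch.
inline void locate_batch(int block_id, int grid_size, int batch_count,
                         const BatchStrides& s,
                         int& g_idx, std::int64_t& a_off, std::int64_t& b0_off,
                         std::int64_t& b1_off, std::int64_t& c_off)
{
    const int num_blocks_per_batch = grid_size / batch_count;
    g_idx  = block_id / num_blocks_per_batch;
    a_off  = g_idx * s.a_batch_stride;  // stands in for GetABasePtr(g_idx) on a dense batch
    b0_off = g_idx * s.b0_batch_stride;
    b1_off = g_idx * s.b1_batch_stride;
    c_off  = g_idx * s.c_batch_stride;
}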
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1) // ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG, template <index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimL,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimN,
typename ADataType, typename ADataType,
typename BDataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc0DataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename GemmAccDataType, typename Acc1DataType,
typename CShuffleDataType, typename CShuffleDataType,
typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename B0ElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
TensorSpecialization ASpec, TensorSpecialization ASpec,
TensorSpecialization BSpec, TensorSpecialization B0Spec,
TensorSpecialization B1Spec, TensorSpecialization B1Spec,
TensorSpecialization CSpec, TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage, ck::index_t BlockSize,
index_t BlockSize, ck::index_t MPerBlock,
index_t MPerBlock, ck::index_t LPerBlock,
index_t NPerBlock, // Gemm0NPerBlock ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
index_t KPerBlock, // Gemm0KPerBlock ck::index_t K1, //
index_t Gemm1NPerBlock, ck::index_t NPerBlock,
index_t Gemm1KPerBlock, ck::index_t L0PerBlock,
index_t AK1, ck::index_t L1,
index_t BK1, ck::index_t MPerWMMA,
index_t B1K1, ck::index_t LPerWMMA,
index_t MPerXDL, ck::index_t NPerWMMA,
index_t NPerXDL, ck::index_t MRepeat,
index_t MXdlPerWave, ck::index_t LRepeat,
index_t NXdlPerWave, ck::index_t NRepeat,
index_t Gemm1NXdlPerWave, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsExtraM, bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1, typename B0BlockTransferThreadClusterLengths_K0_L_K1,
typename BBlockTransferThreadClusterArrangeOrder, typename B0BlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename B0BlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, ck::index_t B0BlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, ck::index_t B0BlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, ck::index_t B0BlockTransferDstScalarPerVector_K1,
bool BBlockLdsExtraN, bool B0BlockLdsAddExtraL,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim, ck::index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector, ck::index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1, ck::index_t B1BlockTransferDstScalarPerVector_L1,
bool B1BlockLdsExtraN, bool B1BlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> ck::index_t NumPrefetch = 1,
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimL,
NumDimK, NumDimK,
NumDimO, NumDimN,
ADataType, ADataType,
BDataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, B0ElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
MaskingSpec> MaskingSpec>
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
@@ -210,64 +121,69 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// TODO ANT: implement bias combination // TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN; static constexpr index_t NumDimGemm0N = NumDimL;
static constexpr index_t NumDimGemm0K = NumDimK; static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM; static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO; static constexpr index_t NumDimGemm1N = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimL;
#endif
static constexpr index_t KPerBlock = K0PerBlock * K1;
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>, Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>, Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec, GemmSpec,
ASpec, ASpec,
BSpec, B0Spec,
B1Spec, B1Spec,
CSpec>; CSpec>;
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{}); Number<K1>{});
} }
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
Number<BK1>{}); Number<K1>{});
} }
static auto static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec), b1_gs_ns_ls_strides_vec),
Number<B1K1>{}); Number<L1>{});
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
@@ -286,12 +202,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
const CGridDesc_G_M_N& c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{ {
} }
@@ -301,14 +217,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{ {
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
@@ -318,208 +234,202 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
}; };
// GridwiseGemm // GridwiseOp
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle<
ADataType, // TODO: distinguish A/B datatype // DataType Family
GemmAccDataType, ADataType,
B0DataType,
Acc0DataType,
B1DataType,
Acc1DataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
// ElementwiseOp Family
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, B0ElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, B0GridDesc_BK0_L_BK1,
B1GridDesc_BK0_N_BK1, B1GridDesc_BL0_N_BL1,
CGridDesc_M_N, CGridDesc_M_N,
NumGemmKPrefetchStage, // Tiling Family
BlockSize,
MPerBlock, MPerBlock,
LPerBlock,
K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
K1, //
NPerBlock, NPerBlock,
KPerBlock, L0PerBlock,
Gemm1NPerBlock, L1,
Gemm1KPerBlock, MPerWMMA,
AK1, LPerWMMA,
BK1, NPerWMMA,
B1K1, MRepeat,
MPerXDL, LRepeat,
NPerXDL, NRepeat,
MXdlPerWave, // ThreadCluster Family
NXdlPerWave, BlockSize,
Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferDstScalarPerVector_K1,
true, true,
ABlockLdsExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterLengths_K0_L_K1,
BBlockTransferThreadClusterArrangeOrder, B0BlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, B0BlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, B0BlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, B0BlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, B0BlockTransferDstScalarPerVector_K1,
true, true,
BBlockLdsExtraN, B0BlockLdsAddExtraL,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim, B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector, B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1, B1BlockTransferDstScalarPerVector_L1,
false, false,
B1BlockLdsExtraN, B1BlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMRepeatPerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument // Argument
// FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(
const ADataType* p_a_grid, const ADataType* p_a_grid,
const BDataType* p_b_grid, const B0DataType* p_b0_grid,
const B1DataType* p_b1_grid, const B1DataType* p_b1_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const index_t M01,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides const index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b0_grid_{p_b0_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b0_grid_desc_bk0_l_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( b1_grid_desc_bl0_n_bl1_{
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{
c_gs_ms_gemm1ns_strides)}, Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{ a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{ b0_grid_desc_g_l_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1_grid_desc_g_n_l_{
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{
c_gs_ms_gemm1ns_strides)}, Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b0_element_op_{b0_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]}, b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1], b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1], b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{ compute_ptr_offset_of_batch_{
a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_} a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_biases;
ignore = p_acc1_biases; ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths; ignore = acc0_biases_gs_ms_ls_lengths;
ignore = acc0_biases_gs_ms_ns_strides; ignore = acc0_biases_gs_ms_ls_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_biases_gs_ms_ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides; ignore = acc1_biases_gs_ms_ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b0_grid_desc_bk0_l_bk1_,
b1_grid_desc_bk0_n_bk1_, b1_grid_desc_bl0_n_bl1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
block_2_ctile_map_)) block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
} }
} }
void Print() const // Pointers
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
}
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const B0DataType* p_b0_grid_;
const B1DataType* p_b1_grid_; const B1DataType* p_b1_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
// tensor descriptor // Tensor Descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-c-tile map // Block to Tile mapping
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op // ElementwiseOp
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; B0ElementwiseOperation b0_element_op_;
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
@@ -527,15 +437,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// check C0 masking and padding // check C0 masking and padding
C0MatrixMask c0_matrix_mask_; C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // Strides for the last M/N/K dimensions of A/B0/B1/C
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_; // for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_lz_kz_nz_;
std::vector<index_t> a_mz_kz_strides_; std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_; std::vector<index_t> b0_lz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_; std::vector<index_t> b1_nz_lz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_nz_strides_;
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; // Batch Offset
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_;
}; };
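For the common case where each of the G/M/L/K/N groups is a single dimension, the lengths/strides vectors consumed by Argument reduce to triples. A sketch for packed row-major A and C (B0/B1 follow the same lengths-then-strides pattern, with strides that depend on how those tensors are actually stored; the helper below is hypothetical):

#include <cstdint>
#include <vector>

using Lengths = std::vector<std::int64_t>;

// Hypothetical shapes for a single-dimension G/M/L/K/N problem:
// A stored as [G, M, K] row-major, C stored as [G, M, N] row-major.
inline void make_a_and_c_descriptors(std::int64_t G, std::int64_t M, std::int64_t K, std::int64_t N,
                                     Lengths& a_gs_ms_ks_lengths, Lengths& a_gs_ms_ks_strides,
                                     Lengths& c_gs_ms_ns_lengths, Lengths& c_gs_ms_ns_strides)
{
    a_gs_ms_ks_lengths = {G, M, K};
    a_gs_ms_ks_strides = {M * K, K, 1}; // packed row-major: K is the contiguous dimension
    c_gs_ms_ns_lengths = {G, M, N};
    c_gs_ms_ns_strides = {M * N, N, 1}; // packed row-major: N is the contiguous dimension
}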
// Invoker // Invoker
@@ -545,38 +457,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!DeviceOp::IsSupportedArgument(arg)) const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
{
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
// Gemm0_K const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0; auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
auto launch_kernel = [&](auto has_main_k_block_loop_) { const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< GridwiseOp,
GridwiseGemm, ADataType,
ADataType, // TODO: distiguish A/B datatype B0DataType,
B1DataType,
CDataType, CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::B0GridDesc_BK0_L_BK1,
DeviceOp::B1GridDesc_BL0_N_BL1,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, B0ElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask, C0MatrixMask,
has_main_k_block_loop_>; typename GridwiseOp::DefaultBlock2CTileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
@@ -584,36 +490,32 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b0_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b0_element_op_,
arg.acc_element_op_, arg.acc_element_op_,
arg.b1_element_op_, arg.b1_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_ptr_offset_of_batch_,
arg.c0_matrix_mask_); arg.c0_matrix_mask_,
arg.block_2_ctile_map_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need if(GridwiseOp::CalculateHasMainKBlockLoop(K))
// to concern Gemm0's loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); return launch_kernel(integral_constant<bool, false>{});
} }
return ave_time;
} }
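The tail of Run() turns a runtime K into a compile-time kernel choice: the launch lambda is instantiated for integral_constant<bool, true> and integral_constant<bool, false>, and only the predicate is evaluated at run time. The mechanism in isolation, with a stand-in predicate instead of GridwiseOp::CalculateHasMainKBlockLoop:

#include <iostream>
#include <type_traits>

template <bool HasMainKBlockLoop>
void kernel_stub()
{
    // In the real code this selects a GridwiseOp::Run instantiation.
    std::cout << "HasMainKBlockLoop = " << HasMainKBlockLoop << '\n';
}

inline void dispatch(int K, int KPerBlock)
{
    auto launch = [&](auto has_main_k_block_loop) {
        constexpr bool has_main_loop = decltype(has_main_k_block_loop)::value;
        kernel_stub<has_main_loop>();
    };

    // Simplified stand-in for the real predicate: does the K loop need more
    // than the prologue/epilogue iterations?
    if(K > KPerBlock)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
}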
// polymorphic // polymorphic
@@ -632,25 +534,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if DEBUG_LOG if(ck::get_device_name() == "gfx1100")
arg.Print(); {
#endif if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
return false;
}
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
return false;
}
}
else
{ {
return false; return false;
} }
// TODO ANT: Check if tensor specialization & strides mismatch if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b0_grid_desc_bk0_l_bk1_,
arg.b1_grid_desc_bl0_n_bl1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
{ {
return false; return false;
} }
@@ -658,19 +575,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
// Check scalar per vector requirement // Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = Gemm1NzRaw; const auto c_extent_lowest = NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{ {
@@ -680,24 +597,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = const auto b0_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0]; B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
const auto b1_stride_lowest = const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0]; B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
const auto c_stride_lowest = const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous arg.c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1)) c_stride_lowest == 1))
{ {
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return true;
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
} }
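The remaining checks above amount to two conditions: every tensor's raw, unmerged extent along its vectorized dimension must be divisible by the configured scalars-per-vector, and — as literally written in the source — at least one of the four lowest-dimension strides must be 1. A compact restatement (hypothetical helper, not CK API):

#include <array>
#include <cstdint>

struct TensorVectorInfo
{
    std::int64_t raw_extent_lowest;  // unpadded length of the vectorized dimension
    std::int64_t stride_lowest;      // stride of that dimension in memory
    std::int64_t scalars_per_vector; // configured SrcScalarPerVector / ScalarPerVector
};

// Mirrors the shape of the check in IsSupportedArgument(): every tensor's
// vectorized extent must divide evenly, and (as written in the source) at
// least one of the tensors must have a unit-stride lowest dimension.
inline bool vector_access_ok(const std::array<TensorVectorInfo, 4>& tensors /* A, B0, B1, C */)
{
    bool all_divisible   = true;
    bool any_unit_stride = false;
    for(const auto& t : tensors)
    {
        all_divisible   = all_divisible && (t.raw_extent_lowest % t.scalars_per_vector == 0);
        any_unit_stride = any_unit_stride || (t.stride_lowest == 1);
    }
    return all_divisible && any_unit_stride;
}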
// polymorphic // polymorphic
@@ -706,114 +619,115 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument( static auto
MakeArgument(
const ADataType* p_a, const ADataType* p_a,
const BDataType* p_b, const B0DataType* p_b0,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b0,
p_b1, p_b1,
p_c, p_c,
p_acc0_biases, p_acc0_biases,
p_acc1_biases, p_acc1_biases,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b0_gs_ls_ks_lengths,
b_gs_ns_ks_strides, b0_gs_ls_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths b1_gs_ns_ls_lengths,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_ns_ls_strides,
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_ns_lengths,
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_ns_strides,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op, a_element_op,
b_element_op, b0_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op}; c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
// FIXME: constness std::unique_ptr<BaseArgument>
std::unique_ptr<BaseArgument> MakeArgumentPointer( MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b0,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
p_acc0_biases, // cast in struct Argument p_acc0_biases,
p_acc1_biases, // cast in struct Argument p_acc1_biases,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b0_gs_ls_ks_lengths,
b_gs_ns_ks_strides, b0_gs_ls_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths b1_gs_ns_ls_lengths,
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_ns_ls_strides,
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_ns_lengths,
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_ns_strides,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_biases_gs_ms_ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides, acc1_biases_gs_ms_ns_strides,
1,
1,
a_element_op, a_element_op,
b_element_op, b0_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op);
} }
static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
...@@ -825,25 +739,33 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle" str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << LPerBlock << ", "
<< KPerBlock << ", " << K0PerBlock << ", "
<< AK1 << ", " << K1 << ", "
<< BK1 << ", " << MPerBlock << ", "
<< NPerWMMA << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << NPerBlock << ", "
<< Gemm1KPerBlock << ", " << L0PerBlock << ", "
<< B1K1 << ", " << L1
<< getGemmSpecializationString(GemmSpec) << ", " << ">"
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << " NumPrefetch: "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", " << NumPrefetch << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", " << "LoopScheduler: "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", " << LoopSchedToString[LoopSched] << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">"; << "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
...
...@@ -20,71 +20,106 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB0,
typename FloatB1,
typename FloatC, typename FloatC,
typename AGridDesc_K0_M_K1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_K0_N_K1, typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_BL0_N_BL1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_wmma( kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
const FloatA* __restrict__ p_a_grid, const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b0_grid, const FloatB0* __restrict__ p_b0_grid,
const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_K0_N_K1 b0_grid_desc_k0_l_k1, const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const B0ElementwiseOperation b0_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, const index_t num_blocks_per_batch =
p_b0_grid, __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
p_c_grid, const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared, p_shared,
a_grid_desc_k0_m_k1, a_grid_desc_ak0_m_ak1,
b0_grid_desc_k0_l_k1, b0_grid_desc_bk0_l_bk1,
b1_grid_desc_l0_n_l1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op, a_element_op,
b_element_op, b0_element_op,
acc_element_op,
b1_element_op,
c_element_op, c_element_op,
c0_matrix_mask,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b0_grid; ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b0_grid_desc_k0_l_k1; ignore = b0_grid_desc_bk0_l_bk1;
ignore = b1_grid_desc_l0_n_l1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b0_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx1100__))
} }
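// Illustration of the batch decomposition above: the 1-D grid is split evenly across the G
// (batch) dimension. For example, with a grid of 64 workgroups and batch_count = 16
// (hypothetical numbers), num_blocks_per_batch = 4, so block 10 works on batch g_idx = 10 / 4 = 2.
// ComputeBasePtrOfStridedBatch supplies the per-batch element offsets that are added to the
// A/B0/B1/C base pointers before calling GridwiseGemm::Run.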
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] // Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] // Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template <index_t BlockSize, template <typename FloatA,
typename FloatA,
typename FloatB0, typename FloatB0,
typename FloatAcc0,
typename FloatB1, typename FloatB1,
typename FloatAcc, typename FloatAcc1,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -93,26 +128,24 @@ template <index_t BlockSize,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1, typename AGridDesc_AK0_M_AK1,
typename B0GridDesc_K0_L_K1, typename B0GridDesc_BK0_L_BK1,
typename B1GridDesc_L0_N_L1, typename B1GridDesc_BL0_N_BL1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
index_t Gemm0MPerBlock, index_t MPerBlock,
index_t Gemm0LPerBlock, index_t LPerBlock,
index_t Gemm0K0PerBlock, index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
index_t Gemm0K1Value, index_t K1Value,
index_t Gemm0MPerWmma, index_t NPerBlock,
index_t Gemm0LPerWmma, index_t L0PerBlock,
index_t Gemm0MRepeat, index_t L1Value,
index_t Gemm0LRepeat, index_t MPerWmma,
index_t Gemm1MPerBlock, index_t LPerWmma,
index_t Gemm1NPerBlock, index_t NPerWmma,
index_t Gemm1L0PerBlock, index_t MRepeat,
index_t Gemm1L1Value, index_t LRepeat,
index_t Gemm1MPerWmma, index_t NRepeat,
index_t Gemm1NPerWmma, index_t BlockSize,
index_t Gemm1MRepeat,
index_t Gemm1NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -141,6 +174,8 @@ template <index_t BlockSize,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool PadN,
bool MaskOutUpperTriangle,
index_t NumGemmKPrefetchStage = 1, index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
...@@ -155,57 +190,44 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
// K1 should be Number<...> // K1Value should be Number<...>
static constexpr auto K1 = Number<Gemm0K1Value>{}; static constexpr auto AK0 = Number<K0PerBlock>{};
static constexpr auto N1 = Number<Gemm1N1Value>{}; static constexpr auto AK1 = Number<K1Value>{};
static constexpr auto BK0 = Number<K0PerBlock>{};
static constexpr auto BK1 = Number<K1Value>{};
static constexpr auto L0 = Number<L0PerBlock>{};
static constexpr auto L1 = Number<L1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerWmma * MRepeat);
static constexpr auto Gemm0LWaves = LPerBlock / (LPerWmma * LRepeat);
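// For illustration (hypothetical tuning values): with MPerBlock = 128, MPerWmma = 16 and
// MRepeat = 1, Gemm0MWaves = 128 / (16 * 1) = 8, i.e. eight waves are laid out along the M
// direction of the Gemm0 block tile; the L direction of the tile is partitioned analogously
// by Gemm0LWaves.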
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { return make_naive_tensor_descriptor(
if constexpr(ABlockLdsExtraM) make_tuple(AK0, Number<MPerBlock>{}, AK1),
{ make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
return a_block_desc_k0perblock_mperblock_k1;
} }
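// Note the K0 stride above: each K0 slice occupies (MPerBlock + ABlockLdsExtraM) * AK1
// elements instead of MPerBlock * AK1. The ABlockLdsExtraM padding rows are the usual LDS
// trick to stagger addresses and reduce bank conflicts when the tile is later read by the
// blockwise GEMM with a different access pattern than it was written with.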
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
{ {
constexpr auto max_lds_align = K1;
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { return make_naive_tensor_descriptor(
if constexpr(BBlockLdsExtraN) make_tuple(BK0, Number<LPerBlock>{}, BK1),
{ make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
return make_naive_tensor_descriptor( }
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
return b_block_desc_k0perblock_nperblock_k1; __host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(L0, Number<NPerBlock>{}, L1),
make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * L1, L1, I1));
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -228,55 +250,68 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 = const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple( const index_t gemm1_bytes_end =
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatB1);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatAcc0);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return (a_block_space_size_aligned * sizeof(FloatA) + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
b_block_space_size_aligned * sizeof(FloatB));
} }
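// The four byte counts above are combined with max() rather than summed because the regions
// are used in different phases: A/B0 staging during Gemm0, the softmax reduction workspace
// and B1 staging afterwards, and the C-shuffle buffer only during the final write-out. The
// corresponding offsets into the single p_shared allocation live in SharedMemTrait below.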
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_K0_N_K1& b0_grid_desc_k0_l_k1, const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const CGridDesc_M_N& c_grid_desc_m_n, const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerWmma)) == 0, (LPerBlock % (LPerWmma * LRepeat)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b0_grid_desc_k0_l_k1.GetLength(I1); const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && const auto KPerBlock = K0PerBlock * K1Value;
K0 == b0_grid_desc_k0_l_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
K1 == b0_grid_desc_k0_l_k1.GetLength(I2))) {
return false; return false;
}
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 &&
N % NPerBlock == 0))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) // check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false; return false;
}
// check gridwise gemm pipeline // check gemm1 gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock; if(!(LPerBlock % (L0PerBlock * L1Value) == 0))
{
return false;
}
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) const auto num_gemm1_k_inner_loop = LPerBlock / (L0PerBlock * L1Value);
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{ {
return false; return false;
} }
...@@ -292,7 +327,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / (K0PerBlock * K1); const index_t num_loop = K / (K0PerBlock * K1Value);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
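// Example (hypothetical sizes): with K = 256, K0PerBlock = 4 and K1Value = 8, the Gemm0 K
// loop runs 256 / 32 = 8 iterations, and the pipeline reports whether a main loop body is
// needed on top of its prefetch prologue/epilogue.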
...@@ -328,6 +363,42 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b0_block_desc_bk0_l_bk1 =
GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
static constexpr auto b1_block_desc_bl0_n_bl1 =
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), L1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
b0_block_desc_bk0_l_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize();
};
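// Offset layout implied above: A starts at 0 with B0 placed right after it (both are only
// needed while Gemm0 is running), while B1 and the softmax reduction workspace start back at
// offset 0, i.e. they reuse the Gemm0 staging area once it is no longer needed; the
// block_sync_lds() calls in Run() order these phases.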
template <bool HasMainKBlockLoop, typename C0MatrixMask, typename Block2CTileMap = DefaultBlock2CTileMap> template <bool HasMainKBlockLoop, typename C0MatrixMask, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid, __device__ static void Run(const FloatA* __restrict__ p_a_grid,
...@@ -335,9 +406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
const FloatB1* __restrict__ p_b1_grid, const FloatB1* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1& a_grid_desc_k0_m_k1,
const B0GridDesc_K0_L_K1& b0_grid_desc_k0_l_k1, const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
const B1GridDesc_L0_N_L1& b1_grid_desc_l0_n_l1, const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -380,9 +451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/*******************************************************************************/ /*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto max_lds_align = K1; // constexpr auto max_lds_align = K1Value;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
...@@ -390,7 +461,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename SrcElementwiseOperation, */ AElementwiseOperation, /* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>, /* typename BlockSliceLengths, */ Sequence<AK0, MPerBlock, AK1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA, /* typename SrcData, */ FloatA,
...@@ -415,134 +486,177 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b0_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation, B0ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_K0_N_K1, B0BlockTransferThreadClusterLengths_K0_L_K1,
BBlockTransferThreadClusterArrangeOrder, B0BlockTransferThreadClusterArrangeOrder,
FloatB, FloatB0,
FloatB, FloatB0,
decltype(b0_grid_desc_k0_l_k1), decltype(b0_grid_desc_k0_l_k1),
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(b0_block_desc_k0perblock_lperblock_k1),
BBlockTransferSrcAccessOrder, B0BlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim, B0BlockTransferSrcVectorDim,
2, 2,
BBlockTransferSrcScalarPerVector, B0BlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, B0BlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, B0ThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
b0_grid_desc_k0_l_k1, b0_grid_desc_k0_l_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, 0, 0),
b_element_op, b0_element_op,
b_block_desc_k0perblock_nperblock_k1, b0_block_desc_k0perblock_lperblock_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
/*******************************************************************************/ /*******************************************************************************/
// Gemm0 // Gemm0
constexpr auto WmmaK = 16; constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
auto blockwise_gemm0 = auto blockwise_gemm0 =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
FloatA, FloatA,
FloatB, FloatB0,
FloatAcc, FloatAcc0,
decltype(a_block_desc_k0perblock_mperblock_k1), decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(b0_block_desc_k0perblock_lperblock_k1),
MPerWmma, MPerWmma,
NPerWmma, LPerWmma,
MRepeat, MRepeat,
NRepeat, LRepeat,
KPack>{}; KPack>{};
// Prepare Register for A*B0 matrix // Prepare Register for A*B0 matrix
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
make_pass_through_transform(laccvgprs)),
make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
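// The 7-D accumulator layout of blockwise_gemm0 is re-viewed here as a 3-D (L0, M, L1)-style
// descriptor so that, after softmax, the same register tile can be consumed directly as the
// A operand of Gemm1 (see a1_blockwise_copy below) without a round trip through LDS.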
/*******************************************************************************/ /*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());
// Shift Per SUB_K // Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0); const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0); const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
/*******************************************************************************/ /*******************************************************************************/
// softmax // softmax
/*******************************************************************************/ /*******************************************************************************/
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatAcc*>(p_shared), math::integer_least_multiple(BlockSize, max_lds_align)); auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// get acc0 8D thread cluster static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = SharedMemTrait::reduction_space_size_aligned);
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / // get acc0 7D thread cluster
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() /
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3);
constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
// get acc0 thread map // get acc0 thread map
constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
make_pass_through_transform(I1)), make_pass_through_transform(I1)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor(
make_tuple( make_tuple(
make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), make_merge_transform(
make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor = const auto threadid_to_l_n_thread_cluster_adaptor =
chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor);
// get acc0 2D thread cluster & 2D thread slice // get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed(
make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize, auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc, FloatAcc0,
decltype(threadid_to_m_n_thread_cluster_adaptor), decltype(threadid_to_l_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n), decltype(thread_cluster_desc_m_l),
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_l)>{};
// Initialize running sum and max of exponentiating row vectors // Initialize running sum and max of exponentiating row vectors
using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
running_sum = 0; running_sum = 0;
running_sum_new = 0; running_sum_new = 0;
running_max = NumericLimits<FloatGemmAcc>::Lowest(); running_max = NumericLimits<FloatAcc0>::Lowest();
running_max_new = NumericLimits<FloatGemmAcc>::Lowest(); running_max_new = NumericLimits<FloatAcc0>::Lowest();
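// running_max / running_sum hold the row-wise softmax statistics accumulated over the L
// tiles processed so far; they are what lets the softmax be computed in a single streaming
// pass (online softmax) instead of materializing the full M x L score matrix.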
/*******************************************************************************/ /*******************************************************************************/
// set up Gemm1 // set up Gemm1
/*******************************************************************************/ /*******************************************************************************/
// B1 matrix in LDS memory, dst of blockwise copy // B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_L0PerBlock_NPerBlock_L1(); constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
constexpr auto b1_block_slice_copy_step = make_multi_index(L0PerBlock, 0, 0);
// A1 matrix in VGPR
constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
Number<L0PerBlock * L1Value / laccvgprs>{},
Number<mrepeat * mwave * mthreadpersubgroup>{},
Number<laccvgprs>{}); // Data duplicated dimension
constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
// A1 has duplicated data
constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;
constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));
// A1 matrix blockwise copy // A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
FloatAcc, FloatAcc0,
FloatA, FloatA,
decltype(acc_thread_desc_k0_m_k1), decltype(acc0_thread_desc_l0perblock_mperblock_l1),
decltype(a1_thread_desc_k0_m_k1), decltype(a1_thread_desc_l0perblock_mperblock_l1),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>, Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
Sequence<1, 0, 2>, Sequence<0, 1, 2>,
2, 2,
n4, laccvgprs,
// dst Rowlane // dst Rowlane
// 0x76543210 0xfedcba98 // 0x76543210 0xfedcba98
// src Rowlane // src Rowlane
...@@ -551,68 +665,77 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// B1 matrix blockwise copy // B1 matrix blockwise copy
auto b1_blockwise_copy = auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation, B0ElementwiseOperation,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>, Sequence<L0, NPerBlock, L1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB1,
FloatAB, FloatB1,
decltype(b1_grid_desc_bk0_n_bk1), decltype(b1_grid_desc_l0_n_l1),
decltype(b1_block_desc_bk0_n_bk1), decltype(b1_block_desc_l0perblock_nperblock_l1),
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim, B1BlockTransferSrcVectorDim,
2, 2,
B1BlockTransferSrcScalarPerVector, B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1, B1BlockTransferDstScalarPerVector_L1,
1, 1,
1, 1,
B1ThreadTransferSrcResetCoordinateAfterRun, B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord true, // DstResetCoord
NumGemmKPrefetchStage>( NumGemmKPrefetchStage>(
b1_grid_desc_bk0_n_bk1, b1_grid_desc_l0_n_l1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b1_element_op, b1_element_op,
b1_block_desc_bk0_n_bk1, b1_block_desc_l0perblock_nperblock_l1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(a1_thread_desc_k0_m_k1.GetElementSpaceSize()); auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared), b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatB1*>(p_shared)+ SharedMemTrait::b1_block_space_offset,
b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());
auto blockwise_gemm1 = auto blockwise_gemm1 =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize, BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
FloatA, FloatA,
FloatB, FloatB1,
FloatAcc, FloatAcc1,
decltype(a1_thread_desc_k0perblock_mperblock_k1), decltype(a1_thread_desc_l0perblock_mperblock_l1),
decltype(b1_block_desc_k0perblock_nperblock_k1), decltype(b1_block_desc_l0perblock_nperblock_l1),
MPerWmma, MPerWmma,
NPerWmma, NPerWmma,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{make_tuple(0, 0, 0, 0, 0)}; KPack>{};
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_k_block_outer_loop = b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (L0PerBlock * L1Value);
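// Loop structure, illustrated with hypothetical sizes: if the full L extent is 1024 and
// LPerBlock = 128, the outer do/while below visits 8 L tiles; within each tile, Gemm1
// consumes the softmaxed scores in chunks of L0PerBlock * L1Value columns, e.g.
// 128 / (4 * 8) = 4 inner iterations.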
// Initialize C // Initialize C
StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc1_thread_buf.Size(), true> c_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
c_thread_buf.Clear(); c_thread_buf.Clear();
/*******************************************************************************/ /*******************************************************************************/
// Flash Attention // Flash Attention
// Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022). // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
index_t gemm1_k_block_outer_index = 0; index_t gemm1_l_block_outer_index = 0;
// Outer loop, along GEMM_L // Outer loop, along GEMM_L
// Inner loop, along GEMM_K // Inner loop, along GEMM_K
do{ do{
auto l_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock))
{
continue;
}
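// With causal (upper-triangular) masking, whole MPerBlock x LPerBlock score tiles that lie
// entirely in the masked region are skipped before Gemm0 is even run; partially masked
// tiles are handled element-wise further below.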
// gemm0 start, A-B swapped // gemm0 start, A-B swapped
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1, GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0perblock_mperblock_k1, a_block_desc_k0perblock_mperblock_k1,
a_blockwise_copy, a_blockwise_copy,
...@@ -620,33 +743,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a_block_buf, a_block_buf,
a_block_slice_copy_step, a_block_slice_copy_step,
b0_grid_desc_k0_l_k1, b0_grid_desc_k0_l_k1,
b_block_desc_k0perblock_nperblock_k1, b0_block_desc_k0perblock_lperblock_k1,
b_blockwise_copy, b0_blockwise_copy,
b0_grid_buf, b0_grid_buf,
b_block_buf, b0_block_buf,
b_block_slice_copy_step, b0_block_slice_copy_step,
blockwise_gemm, blockwise_gemm0,
acc_thread_buf, acc0_thread_buf,
K0BlockMainLoop); K0BlockMainLoop);
// do MNK padding or upper triangular masking // do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN) if constexpr(MaskOutUpperTriangle || PadN)
{ {
// 8d thread_desc in thread scope // 7d thread_desc in thread scope
constexpr auto c_thread_lengths = constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
// 8d block_desc in block scope // 7d block_desc in block scope
constexpr auto c_block_lengths = constexpr auto c_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();
constexpr auto M0 = c_block_lengths[I0]; constexpr auto MREPEAT = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1]; constexpr auto MWAVE = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2]; constexpr auto MTHREADSubGroup = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3]; constexpr auto LREPEAT = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4]; constexpr auto LWAVE = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5]; constexpr auto LSUBGROUP = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6]; constexpr auto LACCVGPRS = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear // works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index // index as well as n-d index
...@@ -656,36 +778,34 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type, typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved false>; // SnakeCurved
auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D( auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); Number<0>{}, Number<0>{});
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local = auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid; auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid; auto l_global = l_local + l_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) if(c0_matrix_mask.IsMaskedElement(m_global, l_global))
{ {
acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity(); acc0_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
} }
else else
{ {
acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]);
} }
}); });
} }
else else
{ static_for<0, acc_thread_buf.Size(), 1>{}( { static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
} }
...@@ -697,7 +817,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf); blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
// TODO: may convert to log domain // TODO: may convert to log domain
running_max_new = mathext::max(max, running_max); running_max_new = mathext::max(max, running_max);
...@@ -717,79 +837,80 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
acc1_thread_buf.Clear(); acc1_thread_buf.Clear();
// preload data into LDS // preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_block_slice_copy_step); b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
// main body // main body
if constexpr(num_gemm1_k_block_inner_loop > 1) if constexpr(num_gemm1_l_block_inner_loop > 1)
{ {
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
// Data cast from FloatAcc to FloatA happen here // Data cast from FloatAcc0 to FloatA happen here
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0), make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc_thread_buf, acc0_thread_buf,
a1_thread_desc_k0_m_k1, a1_thread_desc_l0perblock_mperblock_l1,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
a1_thread_buf); a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
block_sync_lds(); block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds(); block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
b1_block_slice_copy_step); b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
}); });
} }
// tail // tail
{ {
a1_blockwise_copy.Run( a1_blockwise_copy.Run(
acc_thread_desc_k0_m_k1, acc0_thread_desc_l0perblock_mperblock_l1,
make_tuple( make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
acc_thread_buf, acc0_thread_buf,
a1_thread_desc_k0_m_k1, a1_thread_desc_l0perblock_mperblock_l1,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
a1_thread_buf); a1_thread_buf);
block_sync_lds(); block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
} }
} // end gemm1 } // end gemm1
            constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
                blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

            constexpr auto c_mrepeat            = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
            constexpr auto c_mwave              = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
            constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
            constexpr auto c_nrepeat            = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
            constexpr auto c_nwave              = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
            constexpr auto c_nsubgroup          = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
            constexpr auto c_naccvgprs          = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
                make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
                           c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));

            constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
            constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
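            // The flattened (m, n) slice above lets the loop below apply the per-row (iM)
            // softmax statistics to every accumulator element of this thread.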
            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
                    FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
                    FloatAcc1 c    = c_thread_buf[I];    // O
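                    // Flash-attention style online-softmax rescaling: running_max / running_sum
                    // hold the statistics of the L-blocks processed so far, max / running_*_new
                    // include the current block, so c_new remains the correctly normalized
                    // softmax(S) * V accumulated up to this point.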
                    FloatAcc1 c_new =
                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
                        running_sum_new[iM];

...@@ -798,26 +919,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle

                });
            });
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
                                                a_block_reset_copy_step); // rewind K
            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
                                                 b0_block_reset_copy_step); // rewind K and step N

            // update before next j iteration
            running_max = running_max_new;
            running_sum = running_sum_new;

            block_sync_lds(); // wait for gemm1 LDS read
        } while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // This API provides all the dimensions (sizes) you need
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave     = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
            constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);

...@@ -852,7 +973,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block = blockwise_gemm0.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

...@@ -877,7 +998,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
...
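For reference, the running-max / running-sum rescaling used in the gemm1 epilogue above can be exercised on the host in isolation. The sketch below is illustrative only (Row and rescale_accumulate are invented names, not CK APIs): it folds score blocks into one output row the same way the kernel's c_new formula does, and with a V column of all ones the accumulated output converges to softmax(S) * V = 1.

// Minimal host-side sketch of the online-softmax accumulation (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// One output row's state: m = running max, l = running sum of exp(), o = normalized partial P*V.
struct Row
{
    float m = -INFINITY;
    float l = 0.f;
    float o = 0.f;
};

// Fold one L-block into the row state, mirroring the kernel's update:
//   o_new = (l_old * exp(m_old - m_new) * o_old + exp(m_blk - m_new) * pv_blk) / l_new
void rescale_accumulate(Row& r, float m_blk, float l_blk, float pv_blk)
{
    const float m_new = std::max(r.m, m_blk);
    const float l_new = r.l * std::exp(r.m - m_new) + l_blk * std::exp(m_blk - m_new);
    r.o = (r.l * std::exp(r.m - m_new) * r.o + std::exp(m_blk - m_new) * pv_blk) / l_new;
    r.m = m_new;
    r.l = l_new;
}

int main()
{
    // Four scores split into two L-blocks; V is a column of ones, so softmax(s) * v == 1.
    const std::vector<float> s = {0.5f, -1.0f, 2.0f, 0.25f};
    Row r;
    for(std::size_t b = 0; b < s.size(); b += 2)
    {
        const float m_blk  = std::max(s[b], s[b + 1]);
        const float e0     = std::exp(s[b] - m_blk);
        const float e1     = std::exp(s[b + 1] - m_blk);
        const float l_blk  = e0 + e1;
        const float pv_blk = e0 * 1.f + e1 * 1.f; // un-normalized P*V for this block
        rescale_accumulate(r, m_blk, l_blk, pv_blk);
    }
    std::printf("accumulated output = %f (expected 1.0)\n", r.o);
    return 0;
}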
...@@ -1313,8 +1313,8 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          uint32_t LowEightRowlaneIdx,
          uint32_t HighEightRowLaneIdx,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
...
...@@ -369,7 +369,7 @@ struct WmmaGemm
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex3D = MultiIndex<3>;

    __host__ __device__ constexpr WmmaGemm()
    {

...@@ -421,6 +421,46 @@ struct WmmaGemm
                Sequence<5>{}));
    }
// Transposed WMMA Output C' = B' * A'
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
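        // Keep MBlockxRepeat, MWave, NBlockxRepeat and NWave as pass-through dimensions,
        // re-interpret MPerWMMA as MThreadPerSubGroup (num_thread_per_subgroups lanes) and
        // unmerge NPerWMMA into (NSubGroup, NAccVgprs), producing the 7-D per-thread layout
        // named in this function.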
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
}
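    // Illustrative numbers only (actual values come from WmmaSelector): with
    // num_subgroups = 2 and num_acc_vgprs_per_wave = 8, the NPerWMMA = 16 dimension above is
    // viewed as (NSubGroup, NAccVgprs) = (2, 8), i.e. element n maps to (n / 8, n % 8).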
    __device__ static constexpr index_t GetRegSizePerWmma()
    {
        return wmma_instr.num_acc_vgprs_per_wave;

...@@ -493,6 +533,14 @@ struct WmmaGemm

        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }
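    // 3-D counterpart of GetBeginOfThreadBlk: same per-thread (subgroup, lane) origin inside
    // the WMMA output tile, with a trailing 0 index for the per-thread AccVgprs dimension.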
    __device__ static CIndex3D GetBeginOfThreadBlk3D()
    {
        index_t n_offset = GetLaneIdUnderSubGroup();
        index_t m_offset = GetSubGroupId();

        return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
    }
    static constexpr auto wmma =
        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
    static constexpr auto wmma_instr = wmma.selected_wmma;
...