Commit 3b395570 authored by danyao12's avatar danyao12
Browse files

add Gemm2KPerBlock template for split kernels

parent 228e9cd1
...@@ -306,6 +306,7 @@ template <index_t NumDimG, ...@@ -306,6 +306,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -761,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -761,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -1457,6 +1459,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1457,6 +1459,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -48,6 +48,7 @@ template <typename InputDataType, ...@@ -48,6 +48,7 @@ template <typename InputDataType,
index_t KPerBlock, index_t KPerBlock,
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value, index_t AK1Value,
index_t BK1Value, index_t BK1Value,
index_t B1K1Value, index_t B1K1Value,
...@@ -786,13 +787,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -786,13 +787,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n // Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2> struct Gemm2Params
struct Gemm2Params_
{ {
static constexpr index_t Gemm2_M = MPerBlock; // 64 static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128 static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128 static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Sum_K_; static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1; static constexpr index_t A_K0 = Sum_K / A_K1;
...@@ -815,13 +815,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -815,13 +815,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2() __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{ {
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1; constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl; constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves; constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave; constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1; constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl; constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves; constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave; constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
...@@ -842,7 +842,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -842,7 +842,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using ABlockSliceLengths_M0_K0_M1_K1 = using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2) decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
}; };
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm> template <typename Gemm2Params, typename ASrcBlockwiseGemm>
...@@ -1033,14 +1032,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1033,14 +1032,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BBlockwiseCopy = using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType, ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType, GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3), decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3), decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3, BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>, Sequence<0, 1, 2, 3, 4, 5, 6>,
6, 6,
1, 1,
1, 1,
true>; true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0); static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
...@@ -1049,20 +1048,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1049,20 +1048,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
GemmDataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1), decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)), decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)), decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
MPerBlock, MPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm2Params::Sum_K, Gemm2Params::Sum_K,
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
Gemm2Params::GemmMRepeat, Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat, Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack, Gemm2Params::GemmKPack,
true, // TransposeC true, // TransposeC
Gemm2Params::GemmKPack * Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{} XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
...@@ -1343,7 +1342,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1343,7 +1342,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
sizeof(GemmDataType); sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned + const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) * SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType); sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned + const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) * SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType); sizeof(GemmDataType);
...@@ -1353,11 +1352,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1353,11 +1352,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, return math::max(
gemm1_bytes_end, gemm0_bytes_end, gemm1_bytes_end, gemm2_bytes_end, gemm3_bytes_end, c_block_bytes_end);
gemm2_bytes_end,
gemm3_bytes_end,
c_block_bytes_end);
} }
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment