Commit 3b395570 authored by danyao12's avatar danyao12
Browse files

add Gemm2KPerBlock template for split kernels

parent 228e9cd1
......@@ -306,6 +306,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -761,6 +762,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -1457,6 +1459,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -48,6 +48,7 @@ template <typename InputDataType,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
......@@ -786,13 +787,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2>
struct Gemm2Params_
struct Gemm2Params
{
static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Sum_K_;
static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
......@@ -815,13 +815,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1;
constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1;
constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
......@@ -842,7 +842,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
......@@ -1033,14 +1032,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
......@@ -1049,20 +1048,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_thread_desc_k0_n_k1),
decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
MPerBlock,
Gemm1NPerBlock,
Gemm2Params::Sum_K,
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true, // TransposeC
Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
......@@ -1343,7 +1342,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
......@@ -1353,11 +1352,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
c_block_bytes_end);
return math::max(
gemm0_bytes_end, gemm1_bytes_end, gemm2_bytes_end, gemm3_bytes_end, c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment