Commit 228e9cd1 authored by danyao12
Browse files

add Gemm2KPerBlock template

parent 64b5f20f
...@@ -255,6 +255,7 @@ template <index_t NumDimG, ...@@ -255,6 +255,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1, index_t AK1,
index_t BK1, index_t BK1,
index_t B1K1, index_t B1K1,
...@@ -671,6 +672,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -671,6 +672,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
KPerBlock, KPerBlock,
Gemm1NPerBlock, Gemm1NPerBlock,
Gemm1KPerBlock, Gemm1KPerBlock,
Gemm2KPerBlock,
AK1, AK1,
BK1, BK1,
B1K1, B1K1,
...@@ -1317,6 +1319,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1317,6 +1319,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<< MPerBlock << ", " << MPerBlock << ", "
<< Gemm1NPerBlock << ", " << Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", " << Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", " << B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
...@@ -47,6 +47,7 @@ template <typename InputDataType, ...@@ -47,6 +47,7 @@ template <typename InputDataType,
index_t KPerBlock, index_t KPerBlock,
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value, index_t AK1Value,
index_t BK1Value, index_t BK1Value,
index_t B1K1Value, index_t B1K1Value,
...@@ -807,13 +808,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -807,13 +808,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n // Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2> struct Gemm2Params
struct Gemm2Params_
{ {
static constexpr index_t Gemm2_M = MPerBlock; // 64 static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128 static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128 static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Sum_K_; static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1; static constexpr index_t A_K0 = Sum_K / A_K1;
...@@ -836,13 +836,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -836,13 +836,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2() __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{ {
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1; constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl; constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves; constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave; constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1; constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl; constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves; constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave; constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
...@@ -863,7 +863,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -863,7 +863,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using ABlockSliceLengths_M0_K0_M1_K1 = using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2) decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
}; };
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm> template <typename Gemm2Params, typename ASrcBlockwiseGemm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment