Commit 228e9cd1 authored by danyao12's avatar danyao12
Browse files

add Gemm2KPerBlock template

parent 64b5f20f
......@@ -255,6 +255,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -671,6 +672,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -1317,6 +1319,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -47,6 +47,7 @@ template <typename InputDataType,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
......@@ -807,13 +808,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2>
struct Gemm2Params_
struct Gemm2Params
{
static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Sum_K_;
static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1;
......@@ -836,13 +836,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1;
constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1;
constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
......@@ -863,7 +863,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
};
using Gemm2Params = Gemm2Params_<>; // tune later
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment