Commit 63e3f3c4 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'mha-train-develop-bwdopt-bias' into mha-train-develop-grad-bias

parents 592b0649 db579ac9
......@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
O % Gemm1NPerBlock == 0))
{
return false;
}
......@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
};
// dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
......@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
......@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
......@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
......@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
......@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
O % Gemm1NPerBlock == 0))
{
return false;
}
......@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
......@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2
{
private:
......@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
......@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
......@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
......@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
......@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
O % Gemm1NPerBlock == 0))
{
return false;
}
......@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
// dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
......@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1651,6 +1648,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
......
......@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
......@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
......@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
if(O != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
O % Gemm1NPerBlock == 0))
{
return false;
}
......@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
......@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2
{
private:
......@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
......@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
......@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
......@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1520,8 +1511,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
max_lds_align);
D0Operator::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Operator::template TypeTransform<D0DataType>::Size;
......@@ -1697,6 +1687,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1733,6 +1725,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1838,7 +1832,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment