Unverified commit 21ef37b4, authored by Dan Yao and committed by GitHub

Merge pull request #889 from ROCmSoftwarePlatform/mha-train-develop-bwdopt-bias

Mha train develop bwdopt bias
parents 1f04cd2b db579ac9
@@ -39,7 +39,7 @@ template <typename InputDataType,
typename KGridDesc_N_K, typename KGridDesc_N_K,
typename D0GridDesc_M_N, typename D0GridDesc_M_N,
typename ZGridDesc_M_N, typename ZGridDesc_M_N,
typename VGridDesc_N0_O_N1, typename VGridDesc_O0_N_O1,
typename YGridDesc_M_O, typename YGridDesc_M_O,
typename LSEGridDesc_M, typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
@@ -49,6 +49,7 @@ template <typename InputDataType,
index_t KPerBlock, index_t KPerBlock,
index_t Gemm1NPerBlock, index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock, index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1Value, index_t AK1Value,
index_t BK1Value, index_t BK1Value,
index_t B1K1Value, index_t B1K1Value,
@@ -124,6 +125,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
static constexpr auto V_K0 = Gemm1NPerBlock / KPerBlock;
static constexpr auto V_N1 = NXdlPerWave;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2 static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time // get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16 static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
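// Hedged standalone sketch of the V_K* / DropoutTile arithmetic above; every number
// below is an illustrative assumption, not a value taken from this kernel instance.
namespace v_split_sketch {
constexpr int KPerBlock      = 32;  // assumed
constexpr int Gemm1NPerBlock = 128; // assumed head-dim tile
constexpr int BK1            = 8;   // assumed
constexpr int num_input_blks = 2;   // assumed mfma.num_input_blks

constexpr int V_K3 = BK1;
constexpr int V_K2 = num_input_blks;
constexpr int V_K1 = KPerBlock / V_K2 / V_K3;    // 2
constexpr int V_K0 = Gemm1NPerBlock / KPerBlock; // 4

// The four K factors must tile the whole Gemm1N (head-dim) extent of V.
static_assert(V_K0 * V_K1 * V_K2 * V_K3 == Gemm1NPerBlock, "V K-split must cover Gemm1N");
// K_K0 * BK1 likewise reproduces Gemm1NPerBlock for the K LDS block above.
static_assert((Gemm1NPerBlock / BK1) * BK1 == Gemm1NPerBlock, "K_K0 split must be exact");
// get_random_8x16() yields 8 numbers per call, so one DropoutTile spans 2 * 8 = 16 columns.
static_assert(num_input_blks * 8 == 16, "DropoutTile");
} // namespace v_split_sketch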
@@ -197,6 +204,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1)); make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
} }
__host__ __device__ static constexpr auto GetKBlockDescriptor_K0PerBlock_NPerBlock_K1()
{
// K matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(K_K0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
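// Illustrative sketch of the (k0, n, k1) -> LDS offset map implied by the strided
// descriptor above; the NPerBlock / BK1 / BBlockLdsExtraN values are assumptions only.
namespace k_lds_addressing_sketch {
constexpr int NPerBlock       = 128; // assumed
constexpr int BK1             = 8;   // assumed
constexpr int BBlockLdsExtraN = 1;   // assumed padding to dodge LDS bank conflicts

// Mirrors make_naive_tensor_descriptor(make_tuple(K_K0, NPerBlock, BK1),
//                                      make_tuple((NPerBlock + Extra) * BK1, BK1, 1)).
constexpr int offset(int k0, int n, int k1)
{
    return k0 * (NPerBlock + BBlockLdsExtraN) * BK1 + n * BK1 + k1;
}

static_assert(offset(0, 0, 1) == 1,    "k1 is the contiguous (vector) dimension");
static_assert(offset(0, 1, 0) == 8,    "stepping n skips one BK1 vector");
static_assert(offset(1, 0, 0) == 1032, "stepping k0 skips a padded N*K1 row");
} // namespace k_lds_addressing_sketch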
__host__ __device__ static constexpr auto GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3()
{
// V matrix in Vgpr, dst of threadwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<V_K1>{}, I1, I1, Number<V_N1>{}, I1, I1, Number<V_K3>{}));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2> template <typename AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1( __host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2) const AccThreadDesc_M0_N0_M1_N1_M2_M3_M4_N2& acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2)
@@ -277,36 +299,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDesc_M_O& y_grid_desc_m_o) const YGridDesc_M_O& y_grid_desc_m_o)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2); const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1); const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr) // P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr) // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr) // dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K) if(O != K)
{ {
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; return false;
} }
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1))) if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{ {
return false; return false;
} }
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) O % Gemm1NPerBlock == 0))
{ {
return false; return false;
} }
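// Host-side sketch of the same shape constraints checked above, with made-up problem and
// tile sizes; the real CheckValidity additionally validates the descriptor lengths.
namespace check_validity_sketch {
constexpr bool check(int M, int N, int K, int O,
                     int MPerBlock, int NPerBlock, int KPerBlock, int Gemm1NPerBlock)
{
    // SizeK must equal SizeO (same attention head size), and every extent must tile evenly.
    return O == K && M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
           O % Gemm1NPerBlock == 0;
}
// seqlen_q = 512, seqlen_k = 1024, head_dim = 128, tiles 64/128/32/128 -- all assumed.
static_assert(check(512, 1024, 128, 128, 64, 128, 32, 128), "accepted");
static_assert(!check(512, 1000, 128, 128, 64, 128, 32, 128), "rejected: N not a tile multiple");
} // namespace check_validity_sketch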
@@ -411,22 +433,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
__device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock( __device__ static auto MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1) const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{ {
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0); const auto O0 = v_grid_desc_o0_n_o1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1); const auto N = v_grid_desc_o0_n_o1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2); const auto O1 = v_grid_desc_o0_n_o1.GetLength(I2);
const auto N = N0 * N1; const auto O = O0 * O1;
const auto NBlock = N / NPerBlock; const auto NBlock = N / NPerBlock;
const auto OBlock = O / Gemm1NPerBlock; const auto OBlock = O / Gemm1NPerBlock;
const auto v_grid_desc_n_o = transform_tensor_descriptor( const auto v_grid_desc_n_o = transform_tensor_descriptor(
v_grid_desc_n0_o_n1, v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(O), make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))), make_merge_transform_v3_division_mod(make_tuple(O0, O1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor( return transform_tensor_descriptor(
v_grid_desc_n_o, v_grid_desc_n_o,
@@ -438,14 +460,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1) __device__ static auto MakeQGradGridDesc_M_K(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1)
{ {
const auto K_K0 = q_grid_desc_k0_m_k1.GetLength(I0); const auto K0_ = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M_ = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K_K1 = q_grid_desc_k0_m_k1.GetLength(I2); const auto K1_ = q_grid_desc_k0_m_k1.GetLength(I2);
return transform_tensor_descriptor( return transform_tensor_descriptor(
q_grid_desc_k0_m_k1, q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M), make_tuple(make_pass_through_transform(M_),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))), make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
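// Toy index-math sketch (sizes assumed) of the merge transform above: the [K0, M, K1]
// grid view and the merged [M, K] view address the same element when k = k0 * K1 + k1.
namespace merge_transform_sketch {
constexpr int K0 = 16, M = 4, K1 = 8; // assumed q_grid_desc_k0_m_k1 lengths

// Row-major [K0, M, K1] offset, standing in for the source descriptor.
constexpr int offset_k0_m_k1(int k0, int m, int k1) { return (k0 * M + m) * K1 + k1; }
// Offset computed from the merged view via division_mod, as the transform does.
constexpr int offset_m_k(int m, int k) { return ((k / K1) * M + m) * K1 + k % K1; }

static_assert(K0 * K1 == 128, "merged K extent");
static_assert(offset_m_k(3, 5 * K1 + 2) == offset_k0_m_k1(5, 3, 2), "views must agree");
static_assert(offset_m_k(0, 0) == 0 && offset_m_k(0, 1) == 1, "K1 stays innermost");
} // namespace merge_transform_sketch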
@@ -467,16 +489,120 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype( using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>; MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcc) // K / V
struct GemmBlockwiseCopy
{
__device__ static auto
MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1)
{
const auto K0_ = v_grid_desc_o0_n_o1.GetLength(I0);
const auto N_ = v_grid_desc_o0_n_o1.GetLength(I1);
const auto K1_ = v_grid_desc_o0_n_o1.GetLength(I2);
constexpr auto V_N3 = NPerXdl;
constexpr auto V_N2 = Gemm0NWaves;
const auto V_N0 = N_ / NPerBlock;
const auto v_grid_desc_n_k = transform_tensor_descriptor(
v_grid_desc_o0_n_o1,
make_tuple(make_pass_through_transform(N_),
make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
v_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(V_N0, V_N1, V_N2, V_N3)),
make_unmerge_transform(make_tuple(V_K0, V_K1, V_K2, V_K3))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<3, 4, 5, 6>{}, Sequence<0, 1, 2, 7>{}));
}
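// Hedged sketch of the unmerge above: the flat n (and, identically, k) coordinate of V is
// split into four nested factors, n = ((n0 * N1 + n1) * N2 + n2) * N3 + n3. The factor
// values here are assumptions chosen only to keep the numbers small.
namespace v_unmerge_sketch {
constexpr int N1 = 2, N2 = 2, N3 = 4; // assumed NXdlPerWave, Gemm0NWaves, NPerXdl stand-ins

constexpr int compose(int n0, int n1, int n2, int n3)
{
    return ((n0 * N1 + n1) * N2 + n2) * N3 + n3;
}

// Round trip: decomposing the composed coordinate recovers each factor.
static_assert(compose(1, 0, 1, 3) % N3 == 3, "n3");
static_assert((compose(1, 0, 1, 3) / N3) % N2 == 1, "n2");
static_assert((compose(1, 0, 1, 3) / N3 / N2) % N1 == 0, "n1");
static_assert(compose(1, 0, 1, 3) / N3 / N2 / N1 == 1, "n0");
} // namespace v_unmerge_sketch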
// K matrix in LDS, dst of blockwise copy
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// V matrix in Vgpr, dst of threadwise copy
static constexpr auto v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename GridDesc_K0_N_K1>
using KBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K_K0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_K1_k2_N0_N1_N2_N3_K3>
using VBlockwiseCopy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
GemmDataType,
GridDesc_K0_K1_k2_N0_N1_N2_N3_K3,
decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLengths()),
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
BK1,
1,
true /* ResetCoordAfterRun */>;
static constexpr auto VBlockBufferSize = V_K0;
static constexpr auto v_block_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0 struct Gemm0
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
static constexpr auto b_block_desc_bk0_n_bk1 = __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
{
// b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 to b_thread_desc_k0_n_k1
// k0_k1_k2 -> k0
// n0_n1_n2_n3 -> n
// k3 -> k1
const auto k0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0);
const auto k1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I1);
const auto k2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I2);
const auto n0 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I3);
const auto n1 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I4);
const auto n2 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I5);
const auto n3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I6);
const auto k3 = b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I7);
return transform_tensor_descriptor(
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(k0, k1, k2)),
make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_pass_through_transform(k3)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
@@ -492,9 +618,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{ {
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, 1, 1>(
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{}); BBlockDesc_BK0_N_BK1{});
} }
@@ -523,31 +647,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord true, // DstResetCoord
NumGemmKPrefetchStage>; NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk); static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
@@ -556,9 +655,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_src_thread_desc_k0_n_k1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_src_thread_desc_k0_n_k1)),
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
@@ -566,10 +665,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack,
false,
KPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, KPack, false>{}.K0PerXdlops,
KPack>; KPack>;
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
}; };
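// For orientation: a scalar reference of the product this Gemm0 bucket tiles,
// dP[m][n] = sum_o dY[m][o] * V[n][o] (reduction over the shared head dimension O,
// which the kernel walks in KPerBlock-sized chunks). Row-major layouts are assumed.
namespace dp_reference_sketch {
inline void dp_reference(const float* dY, const float* V, float* dP, int M, int N, int O)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int o = 0; o < O; ++o)
                acc += dY[m * O + o] * V[n * O + o];
            dP[m * N + n] = acc; // later combined into dS = P .* (dP - rowsum(Y .* dY))
        }
}
} // namespace dp_reference_sketch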
// dV / dK Gemm (type 2 rrr) // dV / dK Gemm (type 2 rrr)
@@ -707,48 +808,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
// Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n // Describes tuning parameter for C2_m_n = A2_m_k * B2_k_n
template <index_t Sum_K_ = NPerXdl * 2> struct Gemm2Params
struct Gemm2Params_
{ {
static constexpr index_t Gemm2_M = MPerBlock; static constexpr index_t Gemm2_M = MPerBlock; // 64
static constexpr index_t Gemm2_K = NPerBlock; static constexpr index_t Gemm2_K = NPerBlock; // 128
static constexpr index_t Gemm2_N = Gemm1NPerBlock; static constexpr index_t Gemm2_N = Gemm1NPerBlock; // 128
static constexpr index_t Sum_K = Sum_K_; static constexpr index_t Sum_K = Gemm2KPerBlock;
static constexpr index_t A_K1 = 8; // P will be row-major static constexpr index_t A_K1 = 8; // dS will be row-major
static constexpr index_t A_K0 = Sum_K / A_K1; static constexpr index_t A_K0 = Sum_K / A_K1;
static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements static constexpr index_t A_LdsPad = 0; // how many multiples of K1 per M * K1 elements
static constexpr index_t B_K1 = B1K1; // dY assumed row-major, typically =2 for fp16
static constexpr index_t B_K0 = Sum_K / B_K1;
static constexpr index_t B_LdsPad = 0; // how many multiples of K1 per N * K1 elements
static_assert(Sum_K % NPerXdl == 0, ""); static_assert(Sum_K % NPerXdl == 0, "");
static constexpr index_t BSrcVectorDim = 1; // Gemm2_N dimension static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; // 1 // 2
static constexpr index_t BSrcScalarPerVector = 4; static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; // 4 // 2
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmNWave = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave; static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk); static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>; static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
using BThreadClusterLengths = static constexpr index_t B_K0 = GemmKLoop; // 2
Sequence<BlockSize / (Gemm2_N / BSrcScalarPerVector), Gemm2_N / BSrcScalarPerVector, 1>;
using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2() __host__ __device__ static constexpr auto GetABlockSliceLengths_M0_K0_M1_K1_M2_K2()
{ {
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl // perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t k = Gemm2Params::Sum_K - 1; constexpr index_t k = Sum_K - 1;
constexpr index_t k2 = k % NPerXdl; constexpr index_t k2 = k % NPerXdl;
constexpr index_t k1 = k / NPerXdl % Gemm0NWaves; constexpr index_t k1 = k / NPerXdl % Gemm0NWaves;
constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave; constexpr index_t k0 = k / NPerXdl / Gemm0NWaves % NXdlPerWave;
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl // perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params::Gemm2_M - 1; constexpr index_t m = Gemm2_M - 1;
constexpr index_t m2 = m % MPerXdl; constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves; constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave; constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
@@ -769,10 +863,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using ABlockSliceLengths_M0_K0_M1_K1 = using ABlockSliceLengths_M0_K0_M1_K1 =
decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2) decltype(GetABlockSliceLengths_M0_K0_M1_K1()); //(2, 1, 1, 2) //(4, 1, 1, 2)
}; };
using Gemm2Params = Gemm2Params_<>; // tune later
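// Hedged recomputation of the Gemm2Params arithmetic above with one self-consistent set of
// assumed tuning values; these are illustrative only, not a validated configuration.
namespace gemm2_params_sketch {
constexpr int BlockSize        = 256; // assumed
constexpr int WarpSize         = 64;  // assumed wavefront size
constexpr int Gemm2_M          = 128; // assumed MPerBlock
constexpr int Gemm2_K          = 128; // assumed NPerBlock
constexpr int Gemm2_N          = 64;  // assumed Gemm1NPerBlock
constexpr int Sum_K            = 64;  // assumed Gemm2KPerBlock
constexpr int MPerXdl          = 32;
constexpr int NPerXdl          = 32;
constexpr int Gemm2NXdlPerWave = 1;   // assumed
constexpr int A_K1             = 8;   // dS row-major vector width
constexpr int k_per_blk        = 8;   // assumed mfma.k_per_blk
constexpr int K0PerXdlops      = 2;   // assumed XdlopsGemm K0-per-xdlops

constexpr int A_K0        = Sum_K / A_K1;                         // 8
constexpr int GemmNWave   = Gemm2_N / Gemm2NXdlPerWave / NPerXdl; // 2
constexpr int GemmMWave   = BlockSize / WarpSize / GemmNWave;     // 2
constexpr int GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;        // 2
constexpr int GemmKLoop   = Gemm2_K / Sum_K;                      // 2
constexpr int GemmKPack   = A_K1 > k_per_blk ? A_K1 : k_per_blk;  // 8
constexpr int B_K3 = GemmKPack;                                   // 8
constexpr int B_K2 = K0PerXdlops;                                 // 2
constexpr int B_K1 = Sum_K / B_K2 / B_K3;                         // 4
constexpr int B_K0 = GemmKLoop;                                   // 2

static_assert(Sum_K % NPerXdl == 0, "Sum_K must be a multiple of NPerXdl");
static_assert(A_K0 * A_K1 == Sum_K, "A K-split must cover Sum_K");
static_assert(B_K0 * B_K1 * B_K2 * B_K3 == Gemm2_K, "B K-split must cover Gemm2_K");
static_assert(GemmMRepeat * GemmMWave * MPerXdl == Gemm2_M, "M tiling must cover Gemm2_M");
} // namespace gemm2_params_sketch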
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm> template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2 struct Gemm2
{ {
private: private:
@@ -795,8 +888,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>(); static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy template <typename ABlockDesc_K0_M_K1>
static constexpr auto b_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>(); __host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmMRepeat,
Gemm2Params::GemmMWave,
MPerXdl>(ABlockDesc_K0_M_K1{});
}
template <typename BBlockDesc_K0_N_K1>
__host__ __device__ static constexpr auto
MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_K0_N_K1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm2Params::GemmNRepeat, 1, 1>(
BBlockDesc_K0_N_K1{});
}
__host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2() __host__ __device__ static constexpr auto MakeABlockDesc_M0_K0_M1_K1_M2_M3_M4_K2()
{ {
@@ -875,49 +982,112 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
template <typename GridDesc_K0_N_K1> __host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< {
ThisThreadBlock, const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
tensor_operation::element_wise::PassThrough, const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
tensor_operation::element_wise::PassThrough, const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
InMemoryDataOperationEnum::Set,
typename Gemm2Params::BBlockSliceLengths, constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
typename Gemm2Params::BThreadClusterLengths, BSrcBlockDesc_N0_K_N1{},
typename Gemm2Params::BThreadClusterArrangeOrder, make_tuple(make_merge_transform_v3_division_mod(
InputDataType, make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_block_desc_n_k,
make_tuple(
make_unmerge_transform(make_tuple(Gemm2Params::GemmNRepeat,
Gemm2Params::GemmNWave,
NPerXdl)), //(1, 1, 32) //(1, 2, 32)
make_unmerge_transform(
make_tuple(Gemm2Params::B_K0,
Gemm2Params::B_K1,
Gemm2Params::B_K2,
Gemm2Params::B_K3))), //(2, 4, 2, 8) //(2, 4, 2, 8)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));
}
static constexpr auto b_block_desc_n0_n1_n2_k0_k1_k2_k3 =
MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3();
using BThreadSlice_N0_N1_N2_K0_K1_K2_K3 =
Sequence<Gemm2Params::GemmNRepeat, 1, 1, 1, Gemm2Params::B_K1, 1, Gemm2Params::B_K3>;
static constexpr auto b_thread_desc_n0_n1_n2_k0_k1_k2_k3 =
make_naive_tensor_descriptor_packed(make_tuple(Number<Gemm2Params::GemmNRepeat>{},
I1,
I1,
I1,
Number<Gemm2Params::B_K1>{},
I1,
Number<Gemm2Params::B_K3>{}));
__host__ __device__ static constexpr auto MakeBThreadDesc_K0_N_K1()
{
constexpr auto b_thread_desc_n_k = transform_tensor_descriptor(
b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<Gemm2Params::GemmNRepeat>{}, I1, I1)),
make_merge_transform_v3_division_mod(make_tuple(
I1, Number<Gemm2Params::B_K1>{}, I1, Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_thread_desc_n_k,
make_tuple(make_pass_through_transform(Number<Gemm2Params::GemmNRepeat>{}),
make_unmerge_transform(make_tuple(Number<Gemm2Params::B_K1>{},
Number<Gemm2Params::B_K3>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
}
static constexpr auto b_thread_desc_k0_n_k1 = MakeBThreadDesc_K0_N_K1();
using BBlockwiseCopy =
ThreadwiseTensorSliceTransfer_v2<GemmDataType,
GemmDataType,
decltype(b_block_desc_n0_n1_n2_k0_k1_k2_k3),
decltype(b_thread_desc_n0_n1_n2_k0_k1_k2_k3),
BThreadSlice_N0_N1_N2_K0_K1_K2_K3,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
1,
true>;
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, 1, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(0, 0, 0, -Gemm2Params::B_K0, 0, 0, 0);
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, FloatGemmAcc,
decltype(b_block_desc_k0_n_k1), decltype(a_block_desc_k0_m_k1),
typename Gemm2Params::BThreadClusterArrangeOrder, // access order == thread order decltype(b_thread_desc_k0_n_k1),
Sequence<1, 0, 2>, decltype(MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_k0_m_k1)),
Gemm2Params::BSrcVectorDim, decltype(MakeGemm2BMmaTileDescriptor_N0_N1_N2_K(b_thread_desc_k0_n_k1)),
2, // DstVectorDim MPerBlock,
Gemm2Params::BSrcScalarPerVector, Gemm1NPerBlock,
Gemm2Params::B_K1, Gemm2Params::Sum_K,
1, MPerXdl,
1, NPerXdl,
true, Gemm2Params::GemmMRepeat,
true, Gemm2Params::GemmNRepeat,
1>; Gemm2Params::GemmKPack,
true, // TransposeC
Gemm2Params::GemmKPack *
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, Gemm2Params::GemmKPack, false>{}
.K0PerXdlops,
Gemm2Params::GemmKPack>;
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
MPerXdl,
NPerXdl,
Gemm2Params::GemmMRepeat,
Gemm2Params::GemmNRepeat,
Gemm2Params::GemmKPack,
true>; // TransposeC
static constexpr auto b_block_slice_copy_step = make_multi_index(Gemm2Params::B_K0, 0, 0);
static constexpr auto c_block_slice_copy_step = static constexpr auto c_block_slice_copy_step =
make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0); make_multi_index(-Gemm2Params::GemmMRepeat, 0, 0, 0, 0, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(-NPerBlock / Gemm2Params::B_K1, 0, 0);
template <typename CGradDesc_M_N> template <typename CGradDesc_M_N>
__host__ __device__ static auto __host__ __device__ static auto
@@ -964,6 +1134,84 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true>; true>;
}; };
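// For orientation: a scalar reference of the dQ product this struct tiles,
// dQ[m][k] = sum_n dS[m][n] * K[n][k] (row-major dS against row-major K, reduced over the
// seq_k dimension in Sum_K-sized chunks); any softmax scaling is ignored in this sketch.
namespace dq_reference_sketch {
inline void dq_reference(const float* dS, const float* K, float* dQ, int M, int N, int Kdim)
{
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < Kdim; ++k)
        {
            float acc = 0.0f;
            for(int n = 0; n < N; ++n)
                acc += dS[m * N + n] * K[n * Kdim + k];
            dQ[m * Kdim + k] = acc;
        }
}
} // namespace dq_reference_sketch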
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
InputDataType,
GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>;
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KPerBlock);
static constexpr auto b_block_reset_copy_step = make_multi_index(0, 0, 0, -Gemm1NPerBlock);
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
@@ -1014,26 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return ygrad_grid_desc_o0_m_o1; return ygrad_grid_desc_o0_m_o1;
} }
template <typename VGridDesc_N0_O_N1_>
__device__ static auto MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
{
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
constexpr auto V_O1 = BK1;
const auto V_O0 = O / V_O1;
const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
v_grid_desc_n0_o_n1,
make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return v_grid_desc_o0_n_o1;
}
}; };
// QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major) // QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
@@ -1042,17 +1270,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
template <typename KGridDesc_K0_N_K1_> template <typename KGridDesc_K0_N_K1_>
__device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1) __device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{ {
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0); const auto K0_ = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N_ = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2); const auto K1_ = k_grid_desc_k0_n_k1.GetLength(I2);
constexpr auto K_N1 = B1K1; constexpr auto N1_ = B1K1;
const auto K_N0 = N / K_N1; const auto N0_ = N_ / N1_;
const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor( const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
k_grid_desc_k0_n_k1, k_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)), make_tuple(make_unmerge_transform(make_tuple(N0_, N1_)),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))), make_merge_transform_v3_division_mod(make_tuple(K0_, K1_))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -1084,75 +1312,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
}; };
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto b2_block_desc_k0_n_k1 = GetB2BlockDescriptor_K0_N_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b2_block_space_size_aligned = math::integer_least_multiple(
b2_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
static constexpr auto a2_block_space_offset = 0;
static constexpr auto b2_block_space_offset = a2_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::a2_block_space_size_aligned +
SharedMemTrait::b2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
// D0 // D0
static constexpr auto D0M2 = Number<4>{}; static constexpr auto D0M2 = Number<4>{};
static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2; static constexpr auto D0M1 = Number<MPerXdl>{} / D0M2;
@@ -1185,12 +1344,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
{ {
using Type = DataType; using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
}; };
template <> template <>
struct TypeTransform<void> struct TypeTransform<void>
{ {
using Type = ck::half_t; using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
}; };
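// The specialization above is what makes the D0 bias optional: with D0DataType = void,
// Size0 == 0, so the "* Size0" term in the D0 byte accounting later in
// GetSharedMemoryNumberOfByte collapses to zero while Type still names a harmless
// placeholder element. A standalone mirror of the idea (unsigned short stands in for
// ck::half_t; that substitution is an assumption of this sketch):
namespace type_transform_sketch {
template <typename DataType>
struct TypeTransform
{
    using Type                 = DataType;
    static constexpr int Size0 = sizeof(DataType);
    static constexpr int Size  = sizeof(DataType);
};

template <>
struct TypeTransform<void> // no bias tensor: zero-byte accounting, placeholder type
{
    using Type                 = unsigned short; // stand-in for ck::half_t
    static constexpr int Size0 = 0;
    static constexpr int Size  = sizeof(unsigned short);
};

static_assert(TypeTransform<float>::Size0 == 4, "bias present: real element size");
static_assert(TypeTransform<void>::Size0 == 0, "bias absent: no bytes charged");
} // namespace type_transform_sketch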
static constexpr index_t NThreadClusterLengths = 32; static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32); static_assert(NPerXdl == 32);
@@ -1254,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadCopy = using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
@@ -1266,6 +1429,89 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>; 2>;
}; };
struct SharedMemTrait
{
// LDS allocation for K
static constexpr auto k_block_desc_k0_n_k1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto k_block_space_size_aligned =
math::integer_least_multiple(k_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a2_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
static constexpr auto k_block_space_offset = 0;
static constexpr auto a_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = k_block_space_size_aligned.value;
static constexpr auto a2_block_space_offset = k_block_space_size_aligned.value;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset =
(math::max(a_block_space_size_aligned.value,
b1_block_space_size_aligned.value,
a2_block_space_size_aligned.value) +
k_block_space_size_aligned.value) *
sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm1_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::b1_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm2_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a2_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t gemm3_bytes_end = (SharedMemTrait::k_block_space_size_aligned +
SharedMemTrait::a_block_space_size_aligned) *
sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
gemm2_bytes_end,
gemm3_bytes_end,
softmax_bytes_end,
d0_bytes_end,
c_block_bytes_end);
}
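// Hedged back-of-envelope of the accounting above, with assumed element counts: the K tile
// stays resident for the whole outer loop and each phase reuses the region behind it, so
// the requirement is roughly K + max(per-phase buffer), then max'ed against the reduction
// and C-shuffle workspaces. None of the sizes below come from real descriptors.
namespace lds_budget_sketch {
constexpr int gemm_elem_bytes = 2;        // assumed fp16 GemmDataType
constexpr int acc_elem_bytes  = 4;        // assumed fp32 FloatGemmAcc / FloatCShuffle
constexpr int k_tile_elems    = 128 * 32; // assumed persistent K block
constexpr int a_tile_elems    = 64 * 32;  // assumed Q / dY block (gemm0 / gemm3)
constexpr int b1_tile_elems   = 64 * 32;  // assumed gemm1 B block
constexpr int a2_tile_elems   = 64 * 64;  // assumed dS block (gemm2)
constexpr int reduction_elems = 256;      // assumed BlockSize accumulators
constexpr int c_shuffle_elems = 64 * 64;  // assumed C-shuffle tile

constexpr int max3(int a, int b, int c) { return a > b ? (a > c ? a : c) : (b > c ? b : c); }

constexpr int gemm_bytes =
    (k_tile_elems + max3(a_tile_elems, b1_tile_elems, a2_tile_elems)) * gemm_elem_bytes;
constexpr int softmax_bytes   = gemm_bytes + reduction_elems * acc_elem_bytes;
constexpr int c_shuffle_bytes = c_shuffle_elems * acc_elem_bytes;

constexpr int lds_bytes = max3(gemm_bytes, softmax_bytes, c_shuffle_bytes);
static_assert(lds_bytes == softmax_bytes, "with these numbers the reduction phase peaks");
} // namespace lds_budget_sketch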
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
typename Block2CTileMap, typename Block2CTileMap,
@@ -1294,7 +1540,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3& d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
@@ -1319,7 +1565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize()); p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize()); p_v_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize()); p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1327,7 +1573,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize()); p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize()); p_vgrad_grid, v_grid_desc_o0_n_o1.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize()); p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1346,70 +1592,67 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock; const index_t num_gemm0_m_block_outer_loop = q_grid_desc_k0_m_k1.GetLength(I1) / MPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock; constexpr index_t num_gemm1_k_block_inner_loop = MPerBlock / Gemm1KPerBlock;
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim // 6 GEMM operations are categorized into 4 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcc) // dP_MNO Gemm (Gemm0 rcc)
// dV_NOM / dK_NKM Gemm (Gemm1 rrr) // dV_NOM / dK_NKM Gemm (Gemm1 rrr)
// Y_MON / dQ_MKN Gemm (Gemm2 crr) // Y_MON / dQ_MKN Gemm (Gemm2 crr)
// S_MNK Gemm (Gemm3 rcc)
// // LDS allocation for K
// set up S / dP Gemm (type 1 rcc) auto k_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<GemmDataType*>(p_shared) + SharedMemTrait::k_block_space_offset,
GemmBlockwiseCopy::k_block_desc_k0_n_k1.GetElementSpaceSize());
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline // K matrix blockwise copy
// Only supports LoopScheduler::Default auto gemm_tile_k_blockwise_copy =
const auto gemm0_gridwise_gemm_pipeline = typename GemmBlockwiseCopy::template KBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: B matrix blockwise copy
auto s_gemm_tile_k_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1, k_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op, b_element_op,
Gemm0::b_block_desc_bk0_n_bk1, GemmBlockwiseCopy::k_block_desc_k0_n_k1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm // Vgpr allocation for V
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC auto v_thread_buf = generate_tuple(
[&](auto i) {
ignore = i;
return StaticBuffer<
AddressSpaceEnum::Vgpr,
GemmDataType,
GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
},
Number<GemmBlockwiseCopy::VBlockBufferSize>{});
const auto v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GemmBlockwiseCopy::MakeVGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(v_grid_desc_o0_n_o1);
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
const auto s_gemm_tile_q_block_reset_copy_step = // V matrix blockwise copy
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0); auto gemm_tile_v_blockwise_copy =
const auto s_gemm_tile_k_block_reset_copy_step = typename GemmBlockwiseCopy::template VBlockwiseCopy<decltype(
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0); v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3)>(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index(
0, 0, wave_m_n_id[I0], block_work_idx_n, 0, wave_id[I1], wave_m_n_id[I1], 0));
//
// set up dP Gemm (type 1 rcc)
//
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
// dP: transform input and output tensor descriptors // Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// dP: transform input tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 = const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1); PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// dP: A matrix blockwise copy // dP: A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy = auto pgrad_gemm_tile_ygrad_blockwise_copy =
@@ -1423,30 +1666,47 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
// dP: B matrix blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: blockwise gemm // dP: blockwise gemm
// we need separate blockwise gemm object because we need separate thread buffer
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer(); auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step = const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0); make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), -MPerBlock, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), 0, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane( constexpr index_t num_ok_block_main_loop = Gemm1NPerBlock / KPerBlock;
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock); //
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm3::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm3::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0,
MPerBlock * (num_gemm0_m_block_outer_loop - 1),
0), // will loop over GemmM dimension
a_element_op,
Gemm3::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm3::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
const auto s_gemm_tile_q_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), -MPerBlock, 0);
// //
// set up dV / dK Gemm (type 2 rrr) // set up dV / dK Gemm (type 2 rrr)
@@ -1490,7 +1750,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV: transform input and output tensor descriptors // dV: transform input and output tensor descriptors
auto vgrad_grid_desc_nblock_nperblock_oblock_operblock = auto vgrad_grid_desc_nblock_nperblock_oblock_operblock =
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_n0_o_n1); MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock(v_grid_desc_o0_n_o1);
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 = const auto q_grid_desc_m0_k_m1 =
@@ -1528,23 +1788,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// set up dQ Gemm (type 3 crr) // set up dQ Gemm (type 3 crr)
// //
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>; using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment // Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize()); Gemm2::a_block_desc_k0_m_k1.GetElementSpaceSize());
auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset, Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3.GetElementSpaceSize());
Gemm2::b_block_desc_k0_n_k1.GetElementSpaceSize());
// // dQ: transform input and output tensor descriptors
// const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
// Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dQ: transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
// dQ: A matrix VGPR-to-LDS blockwise copy // dQ: A matrix VGPR-to-LDS blockwise copy
auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = auto qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
@@ -1553,18 +1807,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(), Gemm2::MakeAThreadOriginOnBlock_M0_K0_M1_K1_M2_M3_M4_K2(),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
// dQ: B matrix global-to-LDS blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm2::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
make_multi_index(n_block_data_idx_on_grid / Gemm2Params::B_K1, 0, 0),
tensor_operation::element_wise::PassThrough{},
Gemm2::b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dQ: blockwise gemm
auto qgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
// dQ: B matrix LDS-to-VGPR blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy = typename Gemm2::BBlockwiseCopy{
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
make_multi_index(0, // nrepeat
k_thread_origin[I1], // nwave
k_thread_origin[I2], // nperxdl
0, // k0
0, // k1
k_thread_origin[I3] / Gemm2Params::GemmKPack, // k2
0)}; // k3
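// With this change the dQ B operand (K) is no longer streamed from global memory into its own
// LDS tile; each thread instead copies its slice of k_block_buf (already resident in LDS) into
// VGPRs, starting at the origin derived from the blockwise gemm's B-thread layout above.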
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
@@ -1704,9 +1962,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
@@ -1740,6 +1995,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
0, //
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
@@ -1790,7 +2046,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
@@ -1821,17 +2078,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
@@ -1841,8 +2090,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
{
block_sync_lds();
}
// Initialize dK&dV
kgrad_thread_buf.Clear();
vgrad_thread_buf.Clear();
// load k
gemm_tile_k_blockwise_copy.Run(k_grid_desc_k0_n_k1,
k_grid_buf,
GemmBlockwiseCopy::k_block_desc_k0_n_k1,
k_block_buf,
I0);
// load v
static_for<0, GemmBlockwiseCopy::VBlockBufferSize, 1>{}([&](auto ii) {
gemm_tile_v_blockwise_copy.Run(v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
v_grid_buf,
GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
v_thread_buf(Number<ii>{}));
gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, GemmBlockwiseCopy::v_block_slice_copy_step);
});
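// K is staged into LDS once and the full V tile is prefetched into per-thread VGPR buffers
// (VBlockBufferSize slices) before entering the M loop, so neither operand needs to be
// re-read from global memory inside the loop over M blocks.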
do
{
auto m_block_data_idx_on_grid =
@@ -1909,22 +2186,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
lse_thread_buf);
// S = Q * K^T
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
q_grid_desc_k0_m_k1,
Gemm0::a_block_desc_ak0_m_ak1,
s_gemm_tile_q_blockwise_copy,
q_grid_buf,
gemm0_a_block_buf,
Gemm0::a_block_slice_copy_step,
k_grid_desc_k0_n_k1,
Gemm0::b_block_desc_bk0_n_bk1,
s_gemm_tile_k_blockwise_copy,
k_grid_buf,
gemm0_b_block_buf,
Gemm0::b_block_slice_copy_step,
s_blockwise_gemm,
s_slash_p_thread_buf,
num_k_block_main_loop);
{
// preload data into LDS
s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
Gemm3::a_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
s_slash_p_thread_buf.Clear();
s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
gemm3_a_block_buf);
// main body
if constexpr(HasMainKBlockLoop)
{
index_t i = 0;
do
{
s_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_k0_m_k1, q_grid_buf);
block_sync_lds();
s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
block_sync_lds();
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1, Gemm3::a_block_slice_copy_step);
s_gemm_tile_q_blockwise_copy.RunWrite(Gemm3::a_block_desc_ak0_m_ak1,
gemm3_a_block_buf);
++i;
} while(i < (num_ok_block_main_loop - 1));
}
// tail
{
block_sync_lds();
s_blockwise_gemm.Run(gemm3_a_block_buf, k_block_buf, s_slash_p_thread_buf);
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_slice_copy_step);
}
} // end gemm S
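// The block above hand-rolls the schedule that gemm0_gridwise_gemm_pipeline previously provided:
// preload one tile, overlap the global read of tile i+1 with the multiplication of tile i, and
// finish with a tail iteration. A minimal illustrative sketch of that schedule, with hypothetical
// Copy/Gemm/Sync types standing in for the blockwise copy and blockwise gemm objects (not part of
// this kernel):
//
// template <bool HasMainLoop, typename Copy, typename Gemm, typename Sync>
// void run_pipeline(Copy& copy, Gemm& gemm, Sync&& sync, int num_loop)
// {
//     copy.run_read();     // fetch tile 0 from global memory
//     copy.move_window();  // advance the source window to tile 1
//     sync();              // wait until previous LDS consumers are done
//     copy.run_write();    // commit tile 0 to LDS
//     if constexpr(HasMainLoop)
//     {
//         for(int i = 0; i < num_loop - 1; ++i)
//         {
//             copy.run_read();   // fetch tile i+1 while ...
//             sync();
//             gemm.run();        // ... tile i is consumed from LDS
//             sync();
//             copy.move_window();
//             copy.run_write();  // commit tile i+1 to LDS
//         }
//     }
//     sync();
//     gemm.run();                // tail: consume the last tile
// }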
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
@@ -1993,7 +2303,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
@@ -2107,11 +2417,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dV = P_drop^T * dY
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But that is impossible to implement given that
// the A1 source buffer is a static buffer holding the output of the first GEMM and
// requires a constexpr offset by design. Therefore, we pass the tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
@@ -2173,22 +2483,51 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
{
// preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf);
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
pgrad_thread_buf.Clear();
pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm0::a_block_desc_ak0_m_ak1,
gemm0_a_block_buf);
// main body
if constexpr(num_ok_block_main_loop > 1)
{
static_for<0, num_ok_block_main_loop - 1, 1>{}([&](auto i) {
pgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_o0_m_o1,
ygrad_grid_buf);
block_sync_lds();
pgrad_blockwise_gemm.Run(
gemm0_a_block_buf, v_thread_buf(Number<i>{}), pgrad_thread_buf);
block_sync_lds();
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, Gemm0::a_block_slice_copy_step);
pgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm0::a_block_desc_ak0_m_ak1,
gemm0_a_block_buf);
});
}
// tail
{
block_sync_lds();
pgrad_blockwise_gemm.Run(gemm0_a_block_buf,
v_thread_buf(Number<num_ok_block_main_loop - 1>{}),
pgrad_thread_buf);
}
} // end gemm dP
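// Note that the dP gemm reuses Gemm0's A-side LDS tile for dY, while its B operand (V) comes
// directly from the v_thread_buf slices prefetched before the M loop, so the B tile no longer
// passes through LDS here.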
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
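// Softmax backward, for reference: with P = softmax(S) applied row-wise and dP = dY * V^T,
// dS_ij = P_ij * (dP_ij - sum_k(P_ik * dP_ik)), and the row reduction sum_k(P_ik * dP_ik)
// equals the row sum of the element-wise product Y * dY, i.e. exactly the y_dot_ygrad value
// reduced above, so it is reused here instead of re-reducing dP.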
@@ -2220,9 +2559,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dQ = scalar * dS * K
qgrad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dQ
// load QGrad Gemm B
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
// load QGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
@@ -2245,16 +2581,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
gemm2_a_block_buf);
}
// k slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// sgrad slice window is moved by loop index
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm2::b_block_slice_copy_step);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
gemm2_b_block_buf);
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
gemm2_b_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before read
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, qgrad_thread_buf);
qgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_thread_buf, qgrad_thread_buf);
}); // end gemm dQ
// atomic_add dQ
qgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
@@ -2267,11 +2604,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// dK = scalar * dS^T * Q
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy the pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But that is impossible to implement given that
// the A1 source buffer is a static buffer holding the output of the first GEMM and
// requires a constexpr offset by design. Therefore, we pass the tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
@@ -2331,17 +2668,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
s_gemm_tile_q_block_reset_copy_step); // rewind K and step M
s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_k0_n_k1,
s_gemm_tile_k_block_reset_copy_step); // rewind K
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1,
pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O and step M
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_n0_k_n1,
Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
Gemm2::b_block_reset_copy_step); // rewind N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1, kgrad_gemm_tile_q_block_next_copy_step); // step M
@@ -2353,10 +2684,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(-1, 0, 0, 0, 0, 0));
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
s_blockwise_gemm.MoveBBlockSrcSliceWindow(Gemm3::b_block_reset_copy_step);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
@@ -142,8 +142,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
y_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDesc_M& d_grid_desc_m,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const float p_drop)
{
const FloatD p_dropout = type_convert<FloatD>(1.0f - p_drop);
const tensor_operation::element_wise::Scale scale_p_dropout(p_dropout);
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -247,7 +251,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
FloatD,
decltype(d_thread_desc_mblock_m1),
decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale,
Sequence<1, 1>,
Sequence<0, 1>,
1,
@@ -258,7 +262,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx_m, // mblock
get_thread_local_1d_id()), // mperblock
ck::tensor_operation::element_wise::PassThrough{}};
scale_p_dropout};
// copy from VGPR to Global
d_thread_copy_vgpr_to_global.Run(d_thread_desc_mblock_m1,