Commit db579ac9 authored by danyao12's avatar danyao12
Browse files

add templates for bwd gridwise gemm

parent 9e527364
...@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}; };
// dP Gemm (type 1 rcc) // dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0 struct Gemm0
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2> template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1( __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2) const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
...@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
} }
static constexpr auto b_src_thread_desc_k0_n_k1 = static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2); GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadCopy = using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// set up dP Gemm (type 1 rcc) // set up dP Gemm (type 1 rcc)
// //
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm // dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0)); pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
...@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic) if constexpr(Deterministic)
......
...@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}; };
// dP Gemm (type 1 rcc, B in Vgpr) // dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0 struct Gemm0
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3> template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1( __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3) const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
...@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
} }
static constexpr auto b_src_thread_desc_k0_n_k1 = static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3); GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}; };
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm> template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2 struct Gemm2
{ {
private: private:
...@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>(); static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1> template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&) MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
...@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3() __host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{ {
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0); const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1); const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2); const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128) constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1, BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod( make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8) make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128 make_pass_through_transform(K_)), // 128 // 128
...@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}; };
// S Gemm (type 4 rcc, B in LDS) // S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3 struct Gemm3
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
GemmDataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadCopy = using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up dP Gemm (type 1 rcc) // set up dP Gemm (type 1 rcc)
// //
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment // Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
...@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up S Gemm (type 4 rcc) // set up S Gemm (type 4 rcc)
// //
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment // Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
...@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// //
// set up dQ Gemm (type 3 crr) // set up dQ Gemm (type 3 crr)
// //
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>; using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment // Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
...@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic) if constexpr(Deterministic)
......
...@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}; };
// dP Gemm (type 1 rcc) // dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0 struct Gemm0
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2> template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1( __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2) const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
...@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
static constexpr auto b_src_thread_desc_k0_n_k1 = static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2); GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadCopy = using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1606,6 +1603,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1606,6 +1603,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// set up dP Gemm (type 1 rcc) // set up dP Gemm (type 1 rcc)
// //
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm // dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0)); pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
...@@ -2018,7 +2017,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2018,7 +2017,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic) if constexpr(Deterministic)
......
...@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}; };
// dP Gemm (type 1 rcc, B in Vgpr) // dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0 struct Gemm0
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3> template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1( __host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3) const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
...@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
static constexpr auto b_src_thread_desc_k0_n_k1 = static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3); GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}; };
// dQ Gemm (type 3 crr) // dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm> template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2 struct Gemm2
{ {
private: private:
...@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>(); static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1> template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&) MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
...@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3() __host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{ {
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0); const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1); const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2); const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128) constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1, BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod( make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8) make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128 make_pass_through_transform(K_)), // 128 // 128
...@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}; };
// S Gemm (type 4 rcc, B in LDS) // S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3 struct Gemm3
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 = static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType, GemmDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadCopy = using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
...@@ -1652,6 +1643,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1652,6 +1643,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up dP Gemm (type 1 rcc) // set up dP Gemm (type 1 rcc)
// //
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment // Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
...@@ -1688,6 +1681,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1688,6 +1681,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up S Gemm (type 4 rcc) // set up S Gemm (type 4 rcc)
// //
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment // Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
...@@ -1793,7 +1788,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1793,7 +1788,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// //
// set up dQ Gemm (type 3 crr) // set up dQ Gemm (type 3 crr)
// //
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>; using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment // Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
...@@ -2093,7 +2090,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2093,7 +2090,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic) if constexpr(Deterministic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment