Commit 148fa857 authored by danyao12's avatar danyao12
Browse files

Merge branch 'mha-train-develop' into mha-train-develop-dropout8bit

parents af695bee 21ef37b4
......@@ -544,16 +544,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
};
// dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
......@@ -580,7 +577,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -1296,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1523,6 +1520,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
......@@ -1857,7 +1856,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
......@@ -547,16 +547,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
......@@ -584,7 +581,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -847,7 +844,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2
{
private:
......@@ -870,9 +867,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
......@@ -969,12 +963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
......@@ -1120,16 +1114,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
};
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
......@@ -1183,9 +1174,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1381,7 +1372,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1594,6 +1585,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1630,6 +1623,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1735,7 +1730,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......@@ -1980,7 +1977,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
......@@ -565,16 +565,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
// dP Gemm (type 1 rcc)
template <typename BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2 =
GetVThreadDescriptor_K0_K1_N0_N1_N2_N3_K2();
template <typename BThreadDesc_K0_K1_N0_N1_N2_N3_K2>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_N0_N1_N2_N3_K2& b_thread_desc_k0_k1_n0_n1_n2_n3_k2)
......@@ -601,7 +598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_n0_n1_n2_n3_k2);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_N0_N1_N2_N3_K2{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -1364,7 +1361,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1606,6 +1603,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_n0_n1_n2_n3_k2)>;
// dP: blockwise gemm
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
pgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
......@@ -2018,7 +2017,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
......@@ -568,16 +568,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dP Gemm (type 1 rcc, B in Vgpr)
template <typename BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B source matrix layout in VGPR
static constexpr auto b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GetVThreadDescriptor_K0_K1_K2_N0_N1_N2_N3_K3();
template <typename BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3>
__host__ __device__ static constexpr auto GetBThreadDescriptor_K0_N_K1(
const BThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)
......@@ -605,7 +602,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
static constexpr auto b_src_thread_desc_k0_n_k1 =
GetBThreadDescriptor_K0_N_K1(b_src_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3);
GetBThreadDescriptor_K0_N_K1(BSrcThreadDesc_K0_K1_K2_N0_N1_N2_N3_K3{});
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
......@@ -868,7 +865,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// dQ Gemm (type 3 crr)
template <typename Gemm2Params, typename ASrcBlockwiseGemm>
template <typename Gemm2Params, typename ASrcBlockwiseGemm, typename BSrcBlockDesc_N0_K_N1>
struct Gemm2
{
private:
......@@ -891,9 +888,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_k0_m_k1 = GetA2BlockDescriptor_K0_M_K1<Gemm2Params>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_n0_k_n1 = GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_K0_M_K1>
__host__ __device__ static constexpr auto
MakeGemm2AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_K0_M_K1&)
......@@ -990,12 +984,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto MakeBBlockDesc_N0_N1_N2_K0_K1_K2_K3()
{
const auto N0_ = b_block_desc_n0_k_n1.GetLength(I0);
const auto K_ = b_block_desc_n0_k_n1.GetLength(I1);
const auto N1_ = b_block_desc_n0_k_n1.GetLength(I2);
const auto N0_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I0);
const auto K_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I1);
const auto N1_ = BSrcBlockDesc_N0_K_N1{}.GetLength(I2);
constexpr auto b_block_desc_n_k = transform_tensor_descriptor( //(32, 128) //(64, 128)
b_block_desc_n0_k_n1,
BSrcBlockDesc_N0_K_N1{},
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(N0_, N1_)), //(4, 8) //(8, 8)
make_pass_through_transform(K_)), // 128 // 128
......@@ -1141,16 +1135,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
// S Gemm (type 4 rcc, B in LDS)
template <typename BSrcBlockDesc_K0_N_K1>
struct Gemm3
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetKBlockDescriptor_K0PerBlock_NPerBlock_K1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
......@@ -1204,9 +1195,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
BSrcBlockDesc_K0_N_K1,
decltype(MakeGemm3AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
decltype(MakeGemm3BMmaTileDescriptor_N0_N1_N2_K(BSrcBlockDesc_K0_N_K1{})),
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -1426,7 +1417,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord
1>;
using D0ThreadCopy =
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
......@@ -1652,6 +1643,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up dP Gemm (type 1 rcc)
//
using Gemm0 = Gemm0<decltype(GemmBlockwiseCopy::v_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3)>;
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1688,6 +1681,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// set up S Gemm (type 4 rcc)
//
using Gemm3 = Gemm3<decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm3: LDS allocation for A and B: be careful of alignment
auto gemm3_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
......@@ -1793,7 +1788,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
//
// set up dQ Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params, decltype(pgrad_blockwise_gemm)>;
using Gemm2 = Gemm2<Gemm2Params,
decltype(pgrad_blockwise_gemm),
decltype(GemmBlockwiseCopy::k_block_desc_k0_n_k1)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
......@@ -2093,7 +2090,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
if constexpr(Deterministic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment