Commit 7aa37568 authored by danyao12's avatar danyao12
Browse files

qloop dropout optimize

parent 4274096b
......@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8,
true,
4,
S<8, 32, 1>, // B1BlockTransfer
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
......
......@@ -121,6 +121,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -194,6 +195,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -257,7 +259,7 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
......@@ -266,7 +268,8 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm1NXdlPerWave
1, // DropoutStep
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -282,7 +285,7 @@ using DeviceGemmInstance =
8,
true,
1,
S<8, 32, 1>, // B1BlockTransfer
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
......
......@@ -138,12 +138,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -179,12 +179,12 @@ struct BlockwiseDropout
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{} * MRaw);
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{} * MRaw);
}
block_sync_lds();
......@@ -218,21 +218,19 @@ struct BlockwiseDropout
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer,
typename Step,
typename Offset> // N3*N4=8
template <typename ZThreadBuffer, typename Step, typename Offset>
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
ph.get_random_8x16((tmp + i * 8), element_global_1d_id + i * Offset{});
}
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
......
......@@ -40,7 +40,7 @@ template <typename GridwiseGemm,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
typename LSEGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
......@@ -73,8 +73,8 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
......@@ -141,7 +141,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
......@@ -174,7 +174,7 @@ __global__ void
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
......@@ -203,7 +203,7 @@ __global__ void
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6;
ignore = lse_grid_desc_m;
ignore = block_2_ctile_map;
ignore = batch_count;
......@@ -263,6 +263,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -564,6 +565,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -735,8 +737,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n_);
m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
......@@ -791,8 +794,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -876,7 +879,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6,
DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
......@@ -909,7 +912,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
......
......@@ -135,7 +135,7 @@ __global__ void
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......@@ -173,7 +173,7 @@ __global__ void
arg_ptr[group_id].d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......@@ -244,6 +244,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t DropoutStep,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -566,6 +567,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
DropoutStep,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -622,8 +624,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
......@@ -768,12 +770,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5_N6(
z_grid_desc_m_n);
const index_t BlockStart = grid_size_;
......@@ -829,7 +827,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_n6,
z_grid_desc_m_n,
lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......
......@@ -1533,8 +1533,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
......@@ -1966,16 +1966,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(z_tensor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
......@@ -1473,8 +1473,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
......@@ -1865,16 +1865,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(z_tensor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
......@@ -110,6 +110,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
......@@ -119,10 +124,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -136,9 +140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -542,9 +544,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -646,8 +646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmKPack = mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
......@@ -770,9 +769,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack =
math::max(A_K1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
......@@ -1570,8 +1568,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1759,7 +1757,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if constexpr(IsDropout)
{
......@@ -1774,23 +1771,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1806,15 +1807,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -121,6 +121,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -133,10 +138,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -150,9 +154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -522,9 +524,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -657,8 +657,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmKPack = mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -709,9 +708,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack =
math::max(math::lcm(A_K1, B_K1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths =
......@@ -1554,8 +1551,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1722,7 +1719,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if constexpr(IsDropout)
{
......@@ -1737,23 +1733,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1769,14 +1769,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -109,6 +109,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// C desc for source in blockwise copy
......@@ -118,10 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -135,9 +139,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -563,9 +565,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
BBlockDesc_BK0_N_BK1{});
}
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -667,8 +667,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmKPack = mfma.group_size;
static constexpr index_t GemmMWave = Gemm0NWaves; // 4 // 4
static constexpr index_t GemmNWave = Gemm0MWaves; // 1 // 1
......@@ -791,9 +790,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave; // 1 // 1
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl; // 1 // 1
static constexpr index_t GemmKLoop = Gemm2_K / Sum_K; // 2 // 2
static constexpr index_t GemmKPack =
math::max(A_K1, MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t GemmKPack = math::max(A_K1, mfma.k_per_blk);
static constexpr index_t B_K3 = GemmKPack; // 8
static constexpr index_t B_K2 =
XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}.K0PerXdlops; // 2
static constexpr index_t B_K1 = Sum_K / B_K2 / B_K3; // 4
......@@ -1621,8 +1619,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1946,7 +1944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if constexpr(IsDropout)
{
......@@ -1961,23 +1958,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1993,15 +1994,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -120,6 +120,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
// get_random_8x16() generates 8 random numbers each time
static constexpr auto DropoutTile = Number<DropoutNThread * 8>{}; // 16
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -132,10 +137,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
......@@ -149,9 +153,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto group_size = mfma.group_size;
return math::integer_divide_ceil(size, group_size) * group_size;
return math::integer_divide_ceil(size, DropoutTile) * DropoutTile;
}
__device__ static auto GetGemm0WaveIdx()
......@@ -543,9 +545,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t KPack = math::max(math::lcm(AK1, BK1), mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
......@@ -678,8 +678,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
static constexpr index_t GemmKPack = mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
......@@ -730,9 +729,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr index_t GemmMWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmMRepeat = Gemm2_M / GemmMWave / MPerXdl;
static constexpr index_t GemmKPack =
math::max(math::lcm(A_K1, B_K1),
MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
static constexpr index_t GemmKPack = math::max(math::lcm(A_K1, B_K1), mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_K0, Gemm2_N, B_K1>;
using BThreadClusterLengths =
......@@ -1582,8 +1579,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ushort,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
......@@ -1862,7 +1859,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
constexpr auto position_offset = M3 * M4;
// save z to global
if constexpr(IsDropout)
{
......@@ -1877,23 +1873,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
blockwise_dropout
.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
decltype(position_offset),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, raw_n_padded);
decltype(z_tensor_buffer),
decltype(DropoutTile),
true>(s_slash_p_thread_buf,
ph,
global_elem_id,
z_tensor_buffer,
raw_n_padded);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
......@@ -1909,14 +1909,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id_raw = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
auto global_tile_id = z_random_matrix_offset +
(m_global / DropoutTile) * DropoutTile * raw_n_padded +
(n_global / DropoutTile) * DropoutTile;
auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
(n_global % DropoutTile) * raw_n_padded;
auto global_elem_id =
(global_elem_id_raw % M4) * raw_n_padded + (global_elem_id_raw / M4) * M4;
// P_dropped
blockwise_dropout.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf),
decltype(position_offset),
decltype(DropoutTile),
true>(
s_slash_p_thread_buf, ph, global_elem_id, raw_n_padded);
}
......
......@@ -873,8 +873,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......@@ -1022,16 +1022,16 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
decltype(z_tensor_buffer),
false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, z_tenor_buffer);
acc_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment