Commit 6b814bac authored by danyao12

code cleanup for light kernels

parent 73f0c21b
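
Note: the diff below drops the YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock plumbing and the inline "Y dot dY" computation from the light V1/V2 backward kernels; the main loop now loads the precomputed per-row value D through d_grid_desc_m (see the loop comment change from "calculate Y dot dY" to "load d and lse"), with the separate GridwiseYDotYGrad pass producing it. The quantity involved is the row-wise dot product of the forward output Y with its gradient dY, used in the softmax backward step. A minimal host-side reference sketch of that quantity (illustrative only; the function and variable names here are hypothetical and not part of this codebase):

#include <cstddef>
#include <vector>

// Reference for the quantity the removed inline path produced: for each row i
// of the attention output Y and its gradient dY,
//     D[i] = sum over o of Y[i][o] * dY[i][o]
// The light kernels now load this precomputed D instead of forming it in-loop.
std::vector<float> rowwise_y_dot_ygrad(const std::vector<float>& y,
                                       const std::vector<float>& ygrad,
                                       std::size_t M, std::size_t O)
{
    std::vector<float> d(M, 0.f);
    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t o = 0; o < O; ++o)
            d[i] += y[i * O + o] * ygrad[i * O + o];
    return d;
}
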
@@ -97,7 +97,6 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename DGridDescriptor_M,
typename YGradGridDesc_O0_M_O1,
@@ -131,8 +130,6 @@ __global__ void
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
const DGridDescriptor_M d_grid_desc_m,
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
@@ -198,7 +195,6 @@ __global__ void
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
@@ -233,7 +229,6 @@ __global__ void
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_o0_m_o1,
@@ -258,7 +253,6 @@ __global__ void
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
@@ -842,7 +836,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
@@ -882,16 +875,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
y_grid_desc_m_o_))
{
y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
@@ -961,8 +944,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
@@ -1081,7 +1062,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1,
@@ -1116,7 +1096,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.d_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_,
......
@@ -96,7 +96,6 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
typename DGridDescriptor_M,
typename YGradGridDesc_M0_O_M1,
@@ -130,8 +129,6 @@ __global__ void
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
const DGridDescriptor_M d_grid_desc_m,
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
@@ -197,7 +194,6 @@ __global__ void
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_m0_o_m1,
@@ -232,7 +228,6 @@ __global__ void
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
d_grid_desc_m,
ygrad_grid_desc_m0_o_m1,
@@ -257,7 +252,6 @@ __global__ void
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
@@ -855,7 +849,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
d_y_grid_desc_mblock_mperblock_oblock_operblock_{},
@@ -895,16 +888,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
y_grid_desc_m_o_))
{
y_grid_desc_mblock_mperblock_oblock_operblock_ =
GridwiseGemm::MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(
y_grid_desc_m_o_);
}
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
@@ -974,8 +957,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
@@ -1098,7 +1079,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::DGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1,
@@ -1133,7 +1113,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
arg.d_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_,
......
@@ -283,25 +283,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
return true;
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
@@ -400,9 +381,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
@@ -1107,53 +1085,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
true>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_o0_m_o1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
__host__ __device__ static constexpr auto MakeYGradBlockDesc_M_O()
{
const auto O0_ = ygrad_block_desc_o0_m_o1.GetLength(I0);
const auto M_ = ygrad_block_desc_o0_m_o1.GetLength(I1);
const auto O1_ = ygrad_block_desc_o0_m_o1.GetLength(I2);
static_assert(O0_ * O1_ == BlockSliceLength_O_, "");
static_assert(M_ == BlockSliceLength_M_, "");
return transform_tensor_descriptor( //(128, 64)
ygrad_block_desc_o0_m_o1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(O0_, O1_)), //(8, 8)
make_pass_through_transform(M_)), // 128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
}
static constexpr auto ygrad_block_desc_m_o = MakeYGradBlockDesc_M_O();
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
struct SharedMemTrait
{
// // LDS allocation for A and B: be careful of alignment
@@ -1248,8 +1179,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
@@ -1666,101 +1595,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto ygrad_thread_desc_m_o = make_naive_tensor_descriptor_packed(
make_tuple(YDotYGrad_M_O::ThreadSliceLength_M, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
constexpr auto ygrad_thread_cluster_desc = make_cluster_descriptor(
Sequence<YDotYGrad_M_O::ThreadClusterLength_M, YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1>{});
const auto ygrad_thread_cluster_idx = ygrad_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto ygrad_thread_data_on_block_idx =
ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
// performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
GemmDataType,
FloatGemmAcc,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
Sequence<0, 1>,
1, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_thread_data_on_block_idx);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
@@ -1809,7 +1648,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
block_sync_lds();
//
// calculate Y dot dY
// load d and lse
//
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
@@ -2071,6 +1910,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2082,8 +1922,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
gemm2_a_block_buf);
}
// block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
k_block_buf,
Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
@@ -2127,9 +1965,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
} while(0 < gemm0_m_block_outer_index--); // end j loop
// shuffle dK&dV and write
......
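For readers unfamiliar with the descriptor machinery: the MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock factory removed above applied two unmerge transforms to the flat (M, O) output descriptor, splitting each dimension into a block index and an intra-block offset. The coordinate mapping it encoded amounts to the arithmetic in this sketch (the tile sizes and the row-major layout are illustrative assumptions, not fixed by this diff):

// Coordinate mapping of the removed unmerge transform: a 4-d index
// (mblock, mperblock, oblock, operblock) addresses the same element as the
// 2-d index (m, o) of y_grid_desc_m_o, assuming a row-major (M, O) layout.
constexpr int MPerBlock      = 128; // illustrative tile sizes only
constexpr int Gemm1NPerBlock = 64;

constexpr int flat_offset(int mblock, int mperblock, int oblock, int operblock, int O)
{
    const int m = mblock * MPerBlock + mperblock;      // unmerge of M
    const int o = oblock * Gemm1NPerBlock + operblock; // unmerge of O
    return m * O + o;
}
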
@@ -338,25 +338,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(const YGridDesc_M_O& y_grid_desc_m_o)
{
const auto M = y_grid_desc_m_o.GetLength(I0);
const auto O = y_grid_desc_m_o.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto OBlock = O / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
y_grid_desc_m_o,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(OBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
template <typename SrcBlockwiseGemm>
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MB_M0_M1_M2_M3_M4(const LSEGridDesc_M& lse_grid_desc_m)
@@ -455,9 +436,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
k_grid_desc_n_k);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock(YGridDesc_M_O{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
@@ -966,30 +944,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
true>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
struct PGradGemmTile_M_N_O
{
@@ -1180,8 +1134,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const DGridDesc_M& d_grid_desc_m,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
@@ -1627,78 +1579,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
wave_m_n_id[I1]), // NPerXdl
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto P_M3 = p_block_lengths[I5];
constexpr auto P_M4 = p_block_lengths[I6];
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(num_gemm0_m_block_outer_loop - 1, I0, I0, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
constexpr auto y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4 =
make_naive_tensor_descriptor_packed(make_tuple(I1, P_M0, P_M1, P_M2, P_M3, P_M4));
constexpr auto y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4 =
lse_thread_desc_mb_m0_m1_m2_m3_m4; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4),
decltype(y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4),
Sequence<1, m0, m1, m2, m3, m4>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
1,
1,
true /* ResetCoordAfterRun */>{
y_dot_ygrad_block_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4], // mperxdl
acc0_thread_origin[I5],
acc0_thread_origin[I6])};
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mb_m0_m1_m2_m3_m4.GetElementSpaceSize());
@@ -1724,7 +1609,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
//
// calculate Y dot dY
// load d and lse
//
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mb_m0_m1_m2_m3_m4,
@@ -2002,6 +1887,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
sgrad_slice_idx[I3],
sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
block_sync_lds(); // sync before write
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2018,7 +1904,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
gemm2_b_block_buf);
@@ -2120,9 +2005,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(-1, 0, 0, 0, 0, 0));
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(-1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
......
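The YDotYGrad_M_O_ helper removed from both gridwise files partitioned an (M, O) block tile across the workgroup so that each lane reads one 16-byte vector along O per slice. Its constants reduce to the arithmetic below, shown with one assumed configuration (256 threads, a 128 x 64 tile, 2-byte inputs); the concrete values are for illustration only and are tuning parameters in the real kernel:

// Thread-partition arithmetic from the removed YDotYGrad_M_O_ helper.
// Assumed example: BlockSize = 256, 128 x 64 tile, sizeof(InputDataType) = 2.
constexpr int BlockSize          = 256;
constexpr int BlockSliceLength_M = 128;
constexpr int BlockSliceLength_O = 64;
constexpr int SizeofInput        = 2;                // e.g. fp16 input
constexpr int SrcScalarPerVector = 16 / SizeofInput; // 8-wide (16-byte) loads
constexpr int ThreadClusterLength_O = BlockSliceLength_O / SrcScalarPerVector; // 8
constexpr int ThreadClusterLength_M = BlockSize / ThreadClusterLength_O;       // 32
constexpr int ThreadSliceLength_O   = SrcScalarPerVector;                      // 8
constexpr int ThreadSliceLength_M   =
    BlockSliceLength_M * ThreadClusterLength_O / BlockSize;                    // 4
// Same invariants the removed struct asserted: the cluster-times-slice
// lengths must tile the whole block slice in each dimension.
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M, "");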