Commit 118742b6 authored by ltqin

rewrite code: hardcode the YDotYGrad tile parameters in the V1/V2 device ops, map one thread per M row so each thread reduces its full O slice, drop the blockwise-gemm-tiled LSE/y_dot_ygrad descriptors and the LDS atomic_add reduction, and add a get_thread_local_1d_id_in_warp() utility

parent 33fad9ba
@@ -777,18 +777,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         YGridDesc_M_O,
         ORSGridDesc_M,
         BlockSize,
-        MPerBlock,
-        NPerBlock,
+        128,
+        128,
         KPerBlock,
         Gemm1NPerBlock,
         Gemm1KPerBlock,
         AK1,
         BK1,
         B1K1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
+        32,
+        32,
+        1,
+        4,
         ABlockLdsExtraM,
         BBlockLdsExtraN,
         Deterministic>;
@@ -793,18 +793,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         YGridDesc_M_O,
         ORSGridDesc_M,
         BlockSize,
-        MPerBlock,
-        NPerBlock,
+        256,
+        128,
         KPerBlock,
-        Gemm1NPerBlock,
+        32,
         Gemm1KPerBlock,
         AK1,
         BK1,
         B1K1,
-        MPerXDL,
-        NPerXDL,
-        MXdlPerWave,
-        NXdlPerWave,
+        64,
+        64,
+        1,
+        4,
         ABlockLdsExtraM,
         BBlockLdsExtraN,
         Deterministic>;
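Reading the two hunks together: the V1 device op now instantiates its Y dot dY grid kernel with a fixed 128x128 M/N tile, 32x32 XDL instructions and 1x4 XDL-per-wave repeats, and V2 with a 256x128 tile, a Gemm1NPerBlock of 32 and 64x64 XDLs, in place of the forwarded template parameters. A minimal standalone sketch of how such tile numbers relate, reusing the MWave formula that the next hunk deletes from the gridwise code; this is illustrative arithmetic, not CK code:

```cpp
#include <cstdio>

int main()
{
    // V1 literals from the hunk above
    constexpr int MPerBlock   = 128;
    constexpr int MPerXdl     = 32;
    constexpr int MXdlPerWave = 1;

    // Waves along M implied by the tile choice; this is the MWave
    // formula the gridwise hunk below removes.
    constexpr int MWave = MPerBlock / (MXdlPerWave * MPerXdl);

    std::printf("MWave = %d\n", MWave); // prints 4
    return 0;
}
```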
@@ -127,14 +127,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     {
         const index_t M      = lse_grid_desc_m.GetLength(I0);
         const index_t MBlock = M / MPerBlock;
-        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

         const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
             lse_grid_desc_m,
-            make_tuple(make_unmerge_transform(
-                make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
             make_tuple(Sequence<0>{}),
-            make_tuple(Sequence<0, 1, 2, 3>{}));
+            make_tuple(Sequence<0, 1>{}));

         return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
     }
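The LSE descriptor thus flattens from a four-level (MBlock, MXdlPerWave, MWave, MPerXdl) split to a two-level (MBlock, MPerBlock) one. A small sketch of the index arithmetic a two-level unmerge performs, assuming MPerBlock = 128 as in the device-op hunks; MIndex and unmerge_m are illustrative helpers, not CK APIs:

```cpp
#include <cassert>

constexpr int MPerBlock = 128; // tile size assumed from the device-op hunks

struct MIndex
{
    int mblock;    // which MPerBlock-row block
    int mperblock; // row offset inside that block
};

// What make_unmerge_transform(make_tuple(MBlock, MPerBlock)) computes per index
inline MIndex unmerge_m(int m) { return {m / MPerBlock, m % MPerBlock}; }

int main()
{
    MIndex idx = unmerge_m(300);
    assert(idx.mblock == 2 && idx.mperblock == 44); // 300 = 2 * 128 + 44
    return 0;
}
```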
@@ -214,13 +212,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
+        static_assert(BlockSize_ == BlockSliceLength_M_);
+        static constexpr auto ThreadSliceLength_M   = Number<1>{};
         static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
-        static constexpr auto ThreadClusterLength_O =
-            Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
+        static constexpr auto ThreadClusterLength_O = Number<1>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
-        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVector>{};
-        static constexpr auto ThreadSliceLength_M =
-            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
+        static constexpr auto ThreadSliceLength_O   = Number<BlockSliceLength_O_>{};

         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
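Net effect of the new constants: ThreadClusterLength_O collapses to 1 and each thread's O slice becomes the whole BlockSliceLength_O_, so exactly one thread serves each M row (hence the new static_assert tying BlockSize_ to BlockSliceLength_M_). A compilable sketch of the same arithmetic with illustrative sizes; the real BlockSliceLength_O_ comes from the device-op configuration:

```cpp
#include <cstdio>

template <int BlockSize, int BlockSliceLength_M, int BlockSliceLength_O>
struct YDotYGradLayoutSketch
{
    static_assert(BlockSize == BlockSliceLength_M, "one thread per M row");

    static constexpr int ThreadSliceLength_M   = 1;
    static constexpr int ThreadClusterLength_O = 1;
    static constexpr int ThreadClusterLength_M = BlockSize / ThreadClusterLength_O;
    static constexpr int ThreadSliceLength_O   = BlockSliceLength_O;

    // same shape checks as the CK struct above
    static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O, "");
    static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M, "");
};

int main()
{
    using L = YDotYGradLayoutSketch<128, 128, 64>; // illustrative sizes
    std::printf("%d threads, each reducing %d O elements of its own row\n",
                L::ThreadClusterLength_M, L::ThreadSliceLength_O);
    return 0;
}
```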
@@ -274,32 +271,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];

         //
         // set up Y dot dY
         //
-        // S: blockwise gemm
-        auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
-
-        auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
-            Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
-
-        constexpr auto thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-        constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-        constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-        constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);

         constexpr auto ors_thread_desc_mblock_mrepeat_mwave_mperxdl =
-            make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
-
-        // m0, n0 are m/n repeat per wave
-        // m1, n1 are number of waves
-        constexpr auto p_block_lengths =
-            s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
-        constexpr auto P_M0 = p_block_lengths[I0]; // repeats
-        constexpr auto P_M1 = p_block_lengths[I2]; // waves
-        constexpr auto P_M2 = p_block_lengths[I4]; // xdl
+            make_naive_tensor_descriptor_packed(make_tuple(I1, I1));

         constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
             I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));

         constexpr auto y_thread_cluster_desc =
             make_cluster_descriptor(Sequence<I1,
                                              YDotYGrad_M_O::ThreadClusterLength_M,
@@ -311,6 +288,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         const auto y_thread_data_on_block_idx =
             y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();

+        // if(get_thread_global_1d_id() == 1)
+        // {
+        //     printf("y_thread_data_on_block_idx:{ %d, %d, %d,%d}, get_thread_local_1d_id: %d\n",
+        //            y_thread_data_on_block_idx[I0],
+        //            y_thread_data_on_block_idx[I1],
+        //            y_thread_data_on_block_idx[I2],
+        //            y_thread_data_on_block_idx[I3],
+        //            get_thread_local_1d_id());
+        // }
+
         const auto y_thread_data_on_grid_idx =
             make_multi_index(
                 block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
@@ -337,42 +324,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatGemmAcc*>(p_shared), MPerBlock);

-        constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
-            make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
-                                         make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
-        constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
-            ors_thread_desc_mblock_mrepeat_mwave_mperxdl; // reuse LSE thread descriptor because
-                                                          // per-thread LSE data and y_dot_ygrad is
-                                                          // tiled the same way
-
-        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-            FloatGemmAcc,
-            FloatGemmAcc,
-            decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
-            decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
-            Sequence<1, m0, m1, m2>,
-            Sequence<0, 1, 2, 3>,
-            3,
-            m2,
-            1,
-            false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
-                   make_multi_index(I0,                       // mblock
-                                    acc0_thread_origin[I0],   // mrepeat
-                                    acc0_thread_origin[I2],   // mwave
-                                    acc0_thread_origin[I4])}; // mperxdl
-
-        auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
-            y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
-
         if constexpr(Deterministic)
         {
             block_sync_lds();
         }

         //
         // calculate Y dot dY
         //

         // clear accum buffers
         y_dot_ygrad_thread_accum_buf.Clear();
         y_dot_ygrad_block_accum_buf.Clear();
@@ -406,23 +362,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             oblock_idx++;
         } while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));

-        // blockwise reduction using atomic_add
-        block_sync_lds();
-        static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
-            const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
-            y_dot_ygrad_block_accum_buf.AtomicAdd(
-                idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
-        });
-        block_sync_lds();
-
-        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
-        y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
-            y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
-            y_dot_ygrad_block_accum_buf,
-            y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
-            make_tuple(I0, I0, I0, I0),
-            y_dot_ygrad_thread_buf);
-
         auto ors_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(ors_grid_desc_m);
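This deletion is the payoff of the new layout: each thread already accumulates across every O element of its row, so y dot dY is private to the thread, and neither the LDS atomic_add round-trip nor the LDS-to-VGPR redistribution removed above has anything left to do. A scalar sketch of the per-thread reduction, with fabricated data and an illustrative O length:

```cpp
#include <cstdio>

int main()
{
    constexpr int O = 64; // illustrative O slice length
    float y[O], ygrad[O];
    for(int o = 0; o < O; ++o)
    {
        y[o]     = 0.5f; // fabricated data
        ygrad[o] = 2.0f;
    }

    // per-thread accumulator, playing the role of y_dot_ygrad_thread_accum_buf;
    // no block_sync_lds() or AtomicAdd is required because no other thread
    // contributes to this row
    float acc = 0.f;
    for(int o = 0; o < O; ++o)
        acc += y[o] * ygrad[o];

    std::printf("y dot dy = %f\n", acc); // 64.0
    return 0;
}
```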
@@ -432,40 +371,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             decltype(ors_thread_desc_mblock_mrepeat_mwave_mperxdl),
             decltype(ors_grid_desc_mblock_mrepeat_mwave_mperxdl),
             ck::tensor_operation::element_wise::PassThrough,
-            Sequence<1, 1, 1, 1>,
-            Sequence<0, 1, 2, 3>,
-            3,
+            Sequence<1, 1>,
+            Sequence<0, 1>,
+            1,
             1,
             InMemoryDataOperationEnum::Set,
             1,
             false>{ors_grid_desc_mblock_mrepeat_mwave_mperxdl,
                    make_multi_index(block_work_idx_m,          // mblock
-                                    0,                          // mrepeat
-                                    acc0_thread_origin[I2],     // mwave
-                                    acc0_thread_origin[I4]),    // mperxdl
+                                    get_thread_local_1d_id()), // mperxdl
                    ck::tensor_operation::element_wise::PassThrough{}};

         if(get_warp_local_1d_id() < 32)
         {
-            static_for<0, MXdlPerWave, 1>{}([&](auto I) {
             // copy from VGPR to Global
             ors_thread_copy_vgpr_to_global.Run(ors_thread_desc_mblock_mrepeat_mwave_mperxdl,
-                                               make_tuple(I0, Number<I>{}, I0, I0),
-                                               y_dot_ygrad_thread_buf,
+                                               make_tuple(I0, I0),
+                                               y_dot_ygrad_thread_accum_buf,
                                                ors_grid_desc_mblock_mrepeat_mwave_mperxdl,
                                                ors_grid_buf);
-
-                ors_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                    ors_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
-            });
         }

-        // if(get_warp_local_1d_id() == 0)
-        // {
-        //     printf(
-        //         "acc0_thread_origin[I0]:%d acc0_thread_origin[I2]: %d acc0_thread_origin[I4]:
-        //         %d\t", acc0_thread_origin[I0], acc0_thread_origin[I2], acc0_thread_origin[I4]);
-        // }
-
         ignore = ors_thread_copy_vgpr_to_global;
         ignore = ors_grid_desc_mblock_mrepeat_mwave_mperxdl;
     }
 };
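The store side matches: the transfer is now a 2D (mblock, mperblock) copy whose per-thread origin is the flat thread id, so each thread writes its single accumulated value to its own row. A sketch of the resulting global row computation, assuming MPerBlock = 128; variable names mirror the diff but the snippet is illustrative:

```cpp
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128; // assumed from the device-op hunks
    int block_work_idx_m = 2;      // illustrative workgroup index along M
    int tid              = 37;     // stand-in for get_thread_local_1d_id()

    // origin (block_work_idx_m, tid) in the (mblock, mperblock) descriptor
    // addresses this global row of the y_dot_ygrad output
    int row = block_work_idx_m * MPerBlock + tid;
    std::printf("thread %d of block %d writes row %d\n", tid, block_work_idx_m, row); // 293
    return 0;
}
```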
@@ -25,4 +25,6 @@ __device__ index_t get_grid_size() { return gridDim.x; }

 __device__ index_t get_block_size() { return blockDim.x; }

+__device__ index_t get_thread_local_1d_id_in_warp() { return threadIdx.x % get_warp_size(); }
+
 } // namespace ck
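The only other change is this one-line utility, which reduces the flat thread index to a lane index inside its wavefront. A host-side sketch of the same arithmetic, assuming the 64-lane gfx9 wavefront that get_warp_size() returns on these targets; on device the input comes from threadIdx.x:

```cpp
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int warp_size = 64; // assumed gfx9 wavefront size
    for(int tid : {0, 63, 64, 130})
    {
        int lane = tid % warp_size; // what get_thread_local_1d_id_in_warp() computes
        std::printf("thread %d -> lane %d\n", tid, lane);
    }
    return 0;
}
```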