"vscode:/vscode.git/clone" did not exist on "5f0c1a1c0dba914ac0b1e9838b94dd022d512aca"
Commit fced127d authored by danyao12
Browse files

only read dO once, reduce data reading from HBM

parent 84f162f9
......@@ -309,13 +309,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr auto ygrad_block_desc_k0_m_k1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// // A matrix in LDS memory, dst of blockwise copy
// static constexpr auto a_block_desc_ak0_m_ak1 =
// GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// // B matrix in LDS memory, dst of blockwise copy
// static constexpr auto b_block_desc_bk0_n_bk1 =
// GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename GridDesc_K0_M_K1>
using QBlockwiseCopy =
......@@ -1002,30 +1002,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
// dY matrix in LDS memory, dst of blockwise copy
static constexpr auto ygrad_block_desc_k0_m_k1 =
static constexpr auto ygrad_block_desc_o0_m_o1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
__host__ __device__ static constexpr auto MakeYGradBlockDesc_M0_K0_M1_K1()
__host__ __device__ static constexpr auto MakeYGradBlockDesc_M_O()
{
const auto K0_ = ygrad_block_desc_k0_m_k1.GetLength(I0);
const auto M_ = ygrad_block_desc_k0_m_k1.GetLength(I1);
const auto K1_ = ygrad_block_desc_k0_m_k1.GetLength(I2);
const auto O0_ = ygrad_block_desc_o0_m_o1.GetLength(I0);
const auto M_ = ygrad_block_desc_o0_m_o1.GetLength(I1);
const auto O1_ = ygrad_block_desc_o0_m_o1.GetLength(I2);
static_assert(O0_ * O1_ == BlockSliceLength_O_, "");
static_assert(M_ == BlockSliceLength_M_, "");
constexpr auto ygrad_block_desc_k_m = transform_tensor_descriptor( //(64, 128)
ygrad_block_desc_k0_m_k1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0_, K1_)), //(8, 8)
return transform_tensor_descriptor( //(128, 64)
ygrad_block_desc_o0_m_o1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(O0_, O1_)), //(8, 8)
make_pass_through_transform(M_)), //128
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
return transform_tensor_descriptor( //(32, 8, 4, 8)
ygrad_block_desc_k_m,
make_tuple(make_unmerge_transform(make_tuple(ThreadClusterLength_O, ThreadSliceLength_O)), //(8, 8)
make_unmerge_transform(make_tuple(ThreadClusterLength_M, ThreadSliceLength_M))), //(32, 4)
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}));
}
static constexpr auto ygrad_block_desc_m_o = MakeYGradBlockDesc_M_O();
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
......@@ -1127,8 +1126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// static constexpr auto reduction_space_offset = ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value;
static constexpr auto reduction_space_offset = (ygrad_block_space_size_aligned.value + q_block_space_size_aligned.value) * sizeof(DataType) / sizeof(FloatGemmAcc);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1543,6 +1541,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto ygrad_thread_desc_m_o = make_naive_tensor_descriptor_packed(make_tuple(
YDotYGrad_M_O::ThreadSliceLength_M, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
......@@ -1552,15 +1552,23 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
constexpr auto ygrad_thread_cluster_desc =
make_cluster_descriptor(Sequence<YDotYGrad_M_O::ThreadClusterLength_M,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1>{});
const auto ygrad_thread_cluster_idx =
ygrad_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto ygrad_thread_data_on_block_idx = ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
......@@ -1574,26 +1582,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
// // performs for ygrad
// auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// YBlockDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
// decltype(y_thread_desc_m0_m1_o0_o1),
// decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
// Sequence<0, 1, 2, 3>,
// 3, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(y_block_desc_mblock_mperblock_oblock_operblock,
// y_thread_data_on_block_idx);
// performs for ygrad
auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
Sequence<0, 1>,
1, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_thread_data_on_block_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset, MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
......@@ -1622,6 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
// load ygrad
gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1, ygrad_grid_buf, GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1, ygrad_block_buf, I0);
block_sync_lds();
//
// calculate Y dot dY
//
......@@ -1630,34 +1643,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
y_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
ygrad_threadwise_copy.Run(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_block_buf,
ygrad_thread_desc_m_o,
make_tuple(I0, I0),
ygrad_thread_buf);
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto y_offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
constexpr auto ygrad_offset =
ygrad_thread_desc_m_o.CalculateOffset(make_multi_index(iM, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<y_offset>{}] * ygrad_thread_buf[Number<ygrad_offset>{}];
});
});
// blockwise reduction using atomic_add
block_sync_lds();
......@@ -1691,9 +1697,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// load q
gemm_tile_q_blockwise_copy.Run(q_grid_desc_k0_m_k1, q_grid_buf, GemmBlockwiseCopy::q_block_desc_k0_m_k1, q_block_buf, I0);
// load ygrad
gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1, ygrad_grid_buf, GemmBlockwiseCopy::ygrad_block_desc_k0_m_k1, ygrad_block_buf, I0);
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment