Unverified commit acea1753, authored by Dan Yao, committed by GitHub

Merge pull request #1026 from ROCmSoftwarePlatform/mha-train-develop-fix-d0vgpr2lds

Fix the flash attention backward D0 gradient output.

Parents: 620eeae8 22593950
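Summary of the change, as read from the hunks below (hedged): all four Qloop backward kernels previously handed the raw FloatGemmAcc (fp32) sgrad_thread_buf straight to the ThreadwiseTensorSliceTransfer_v1r3 VGPR-to-LDS copy and let its Scale element-wise operator apply the dropout rescale rp_dropout during the transfer. The fix stages the D0 gradient first: each element is read through the blockwise GEMM's own C-thread descriptor (pgrad_blockwise_gemm.GetCThreadDesc()), scaled by rp_dropout while still fp32, converted to D0DataType into a dedicated d0grad_thread_buf, and only then copied to LDS with a plain PassThrough operator. A minimal sketch of that scale-then-stage-then-copy pattern, with hypothetical flat indexing and float standing in for ck::half_t (this is not the CK API):

```cpp
// Hypothetical flat rendering of the fixed path; offset() stands in for
// sgrad_thread_desc.CalculateOffset(...), i.e. the accumulator is addressed
// with the GEMM C-thread layout rather than the D0 copy descriptor.
template <int SliceSize, typename OffsetFn>
void store_d0grad_slice(const float* sgrad_vgpr, // fp32 dSoftmax accumulator
                        float* d0grad_lds,       // destination tile in LDS
                        float rp_dropout,        // dropout rescale, 1 / (1 - p)
                        OffsetFn offset)
{
    float d0grad_vgpr[SliceSize]; // staging buffer (D0DataType in the kernel)

    // Scale while still in the accumulator type, then stage the narrowed value.
    for(int i = 0; i < SliceSize; ++i)
        d0grad_vgpr[i] = rp_dropout * sgrad_vgpr[offset(i)];

    // The VGPR->LDS transfer is now a plain PassThrough: no Scale op needed.
    for(int i = 0; i < SliceSize; ++i)
        d0grad_lds[i] = d0grad_vgpr[i];
}
```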
@@ -1237,6 +1237,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
+        static_assert(NXdlPerWave == 1);
+        static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
@@ -1314,11 +1315,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
+            typename TypeTransform<D0DataType>::Type,
             typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
-            tensor_operation::element_wise::Scale,       // CElementwiseOperation
+            tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
@@ -1918,7 +1919,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
-            tensor_operation::element_wise::Scale{rp_dropout});
+            tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
@@ -2205,6 +2206,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
if(p_d0grad_grid != nullptr)
{
+                static constexpr auto& sgrad_thread_desc =
+                    pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -2212,11 +2216,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
+                    static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
+                        constexpr index_t c_offset =
+                            sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
+                        d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
+                            rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
+                    });
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
-                        make_tuple(mr, I0, I0, I0, I0),
-                        sgrad_thread_buf,
+                        make_tuple(I0, I0, I0, I0, I0),
+                        d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
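The net effect of the rewritten store above, in hypothetical plain C++ (static_for unrolled to runtime loops, descriptor offsets reduced to flat strides, sizes assumed for illustration only):

```cpp
#include <array>

constexpr int D0M0      = 2;  // assumed number of M0 repeats per thread
constexpr int SliceSize = 16; // assumed d0grad_thread_buf.Size()

// One staged, pre-scaled slice per mr repeat; the copy itself is PassThrough.
void write_d0grad(const std::array<float, D0M0 * SliceSize>& sgrad_thread,
                  std::array<std::array<float, SliceSize>, D0M0>& lds_tiles,
                  float rp_dropout)
{
    for(int mr = 0; mr < D0M0; ++mr)
    {
        std::array<float, SliceSize> d0grad_thread{}; // staging "VGPRs"
        for(int i = 0; i < SliceSize; ++i)            // scale before narrowing
            d0grad_thread[i] = rp_dropout * sgrad_thread[mr * SliceSize + i];
        lds_tiles[mr] = d0grad_thread;                // VGPR -> LDS stand-in
    }
}
```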
......
@@ -1316,7 +1316,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
+        static_assert(NXdlPerWave == 1);
+        static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
@@ -1394,11 +1396,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
+            typename TypeTransform<D0DataType>::Type,
             typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
-            tensor_operation::element_wise::Scale,       // CElementwiseOperation
+            tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
@@ -2040,7 +2042,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
-            tensor_operation::element_wise::Scale{rp_dropout});
+            tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
@@ -2473,6 +2475,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
if(p_d0grad_grid != nullptr)
{
+                static constexpr auto& sgrad_thread_desc =
+                    pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -2480,12 +2485,23 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
+                    static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
+                        constexpr index_t c_offset =
+                            sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
+                        d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
+                            rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
+                    });
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
-                        make_tuple(mr, I0, I0, I0, I0),
-                        sgrad_thread_buf,
-                        D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
+                        make_tuple(I0, I0, I0, I0, I0),
+                        d0grad_thread_buf,
+                        D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
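Note: on my reading of this _Light_V2 hunk, it carries one change the other three kernels do not: the destination descriptor of the Run call switches from D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2 to D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, so the D0 gradient is no longer written to LDS through the D0 input's descriptor. This matches the branch name mha-train-develop-fix-d0vgpr2lds.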
......
@@ -1304,7 +1304,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
+        static_assert(NXdlPerWave == 1);
+        static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
@@ -1382,11 +1384,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
+            typename TypeTransform<D0DataType>::Type,
             typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
-            tensor_operation::element_wise::Scale,       // CElementwiseOperation
+            tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
@@ -2080,7 +2082,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
-            tensor_operation::element_wise::Scale{rp_dropout});
+            tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
@@ -2406,6 +2408,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
if(p_d0grad_grid != nullptr)
{
+                static constexpr auto& sgrad_thread_desc =
+                    pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -2413,11 +2418,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
+                    static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
+                        constexpr index_t c_offset =
+                            sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
+                        d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
+                            rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
+                    });
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
-                        make_tuple(mr, I0, I0, I0, I0),
-                        sgrad_thread_buf,
+                        make_tuple(I0, I0, I0, I0, I0),
+                        d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
......
@@ -1370,7 +1370,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = 32;
+        static_assert(NXdlPerWave == 1);
+        static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
@@ -1448,11 +1450,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
+            typename TypeTransform<D0DataType>::Type,
             typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
-            tensor_operation::element_wise::Scale,       // CElementwiseOperation
+            tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
@@ -2162,7 +2164,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
-            tensor_operation::element_wise::Scale{rp_dropout});
+            tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
@@ -2634,6 +2636,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
if(p_d0grad_grid != nullptr)
{
+                static constexpr auto& sgrad_thread_desc =
+                    pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -2641,11 +2646,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
+                    static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
+                        constexpr index_t c_offset =
+                            sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
+                        d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
+                            rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
+                    });
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
-                        make_tuple(mr, I0, I0, I0, I0),
-                        sgrad_thread_buf,
+                        make_tuple(I0, I0, I0, I0, I0),
+                        d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
......