Commit 5a032110 authored by letaoqin's avatar letaoqin
Browse files

fix d0 write data form vgpr to lds

parent d7544ea0
...@@ -1314,11 +1314,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1314,11 +1314,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
2>; 2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, typename TypeTransform<D0DataType>::Type,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim 4, // VectorDim
...@@ -1918,7 +1918,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1918,7 +1918,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
...@@ -2205,6 +2205,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2205,6 +2205,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
...@@ -2212,11 +2215,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2212,11 +2215,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
});
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
sgrad_thread_buf, d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
......
...@@ -1394,11 +1394,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1394,11 +1394,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
2>; 2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, typename TypeTransform<D0DataType>::Type,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim 4, // VectorDim
...@@ -2040,7 +2040,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2040,7 +2040,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
...@@ -2473,6 +2473,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2473,6 +2473,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
...@@ -2480,12 +2483,23 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2480,12 +2483,23 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
});
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
sgrad_thread_buf, d0grad_thread_buf,
D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
......
...@@ -1382,11 +1382,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1382,11 +1382,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
2>; 2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, typename TypeTransform<D0DataType>::Type,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim 4, // VectorDim
...@@ -2080,7 +2080,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2080,7 +2080,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
...@@ -2406,6 +2406,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2406,6 +2406,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
...@@ -2413,11 +2416,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2413,11 +2416,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
});
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
sgrad_thread_buf, d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
......
...@@ -2636,6 +2636,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2636,6 +2636,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static constexpr auto& sgrad_thread_desc = static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc(); pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment