Commit d7544ea0 authored by letaoqin's avatar letaoqin
Browse files

fx d0 write data form vgpr to lds

parent 620eeae8
...@@ -1448,17 +1448,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1448,17 +1448,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>; 2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, typename TypeTransform<D0DataType>::Type,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim 4, // VectorDim
4, // ScalarPerVector 4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1< using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
...@@ -2162,7 +2162,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2162,7 +2162,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds( auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
...@@ -2634,6 +2634,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2634,6 +2634,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
...@@ -2641,11 +2643,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2641,11 +2643,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
});
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
sgrad_thread_buf, d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment