Commit d7544ea0 authored by letaoqin's avatar letaoqin
Browse files

fx d0 write data form vgpr to lds

parent 620eeae8
......@@ -1448,11 +1448,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
2>;
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
tensor_operation::element_wise::PassThrough, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
......@@ -2162,7 +2162,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
tensor_operation::element_wise::PassThrough{});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
......@@ -2634,6 +2634,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
if(p_d0grad_grid != nullptr)
{
static constexpr auto& sgrad_thread_desc =
pgrad_blockwise_gemm.GetCThreadDesc();
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
......@@ -2641,11 +2643,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
static_for<0, d0grad_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
sgrad_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
d0grad_thread_buf(i) = ck::type_convert<D0DataType>(
rp_dropout * sgrad_thread_buf(Number<c_offset>{}));
});
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
make_tuple(I0, I0, I0, I0, I0),
d0grad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment