Commit e0f595de authored by letaoqin

kernel save grad bias

parent 042761af
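The change materializes the bias (D0) gradient in the backward kernel: the score gradient sgrad already held in registers is scaled by rp_dropout = 1 / (1 - p_drop), staged through LDS one M0 slice at a time by the new D0ThreadCopyVgprToLds, and then written to the d0grad global buffer by D0BlockwiseCopyLdsToGlobal. The snippet below is only a minimal host-side sketch of that flow for orientation; save_grad_bias, block_offset, and the plain loops are illustrative stand-ins, not composable_kernel APIs.

#include <cstddef>
#include <vector>

// Toy model of the staged write-out added in this commit (illustrative only).
void save_grad_bias(const std::vector<float>& sgrad, // score gradient of one block tile
                    std::vector<float>& d0grad,      // global bias-gradient buffer
                    std::size_t block_offset,        // where this block's tile starts
                    float rp_dropout)                // 1 / (1 - p_drop)
{
    // stand-in for the LDS staging tile the block writes and then flushes
    std::vector<float> lds(sgrad.size());
    for(std::size_t i = 0; i < sgrad.size(); ++i)
        lds[i] = sgrad[i] * rp_dropout; // element_wise::Scale{rp_dropout} in the kernel
    for(std::size_t i = 0; i < lds.size(); ++i)
        d0grad[block_offset + i] = lds[i]; // lds -> global store
}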
@@ -1436,16 +1436,49 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
4, // SrcVectorDim
4, // SrcScalarPerVector
2>;
-using D0ThreadCopyVgprToBlock =
-    ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type,      // SrcData
-                                     typename TypeTransform<D0DataType>::Type,      // DstData
-                                     decltype(d0_thread_desc_),                     // SrcDesc
-                                     decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),   // DstDesc
-                                     Sequence<1, 1, 4, 1, 4>,                       // SliceLengths
-                                     Sequence<0, 1, 2, 3, 4>,                       // DimAccessOrder
-                                     4,                                             // SrcVectorDim
-                                     4,                                             // SrcScalarPerVector
-                                     2>;
+using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+    FloatGemmAcc,
+    typename TypeTransform<D0DataType>::Type,
+    decltype(d0_thread_desc_),
+    decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
+    tensor_operation::element_wise::Scale, // CElementwiseOperation
+    Sequence<1, 1, 4, 1, 4>,               // SliceLengths
+    Sequence<0, 1, 2, 3, 4>,               // AccessOrder
+    4,                                     // VectorDim
+    4,                                     // ScalarPerVector
+    InMemoryDataOperationEnum::Set,        // GlobalMemoryDataOperation
+    1,                                     // DstScalarStrideInVector
+    true>;
+
+using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+    ThisThreadBlock,
+    tensor_operation::element_wise::PassThrough,
+    tensor_operation::element_wise::PassThrough,
+    InMemoryDataOperationEnum::Set,
+    Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>,      // BlockSliceLengths
+    Sequence<1,
+             1,
+             1,
+             BlockSize / NThreadClusterLengths,
+             NThreadClusterLengths,
+             1>,                                      // ThreadClusterLengths
+    Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
+    typename TypeTransform<D0DataType>::Type,         // SrcData
+    typename TypeTransform<D0DataType>::Type,         // DstData
+    decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+    D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // DstDesc
+    Sequence<0, 1, 2, 4, 3, 5>,                       // SrcDimAccessOrder
+    Sequence<0, 1, 2, 3, 5, 4>,                       // DstDimAccessOrder
+    5,                                                // SrcVectorDim
+    4,                                                // DstVectorDim
+    4,                                                // SrcScalarPerVector
+    D0BlockTransferSrcScalarPerVector,                // DstScalarPerVector
+    1,
+    1,
+    true,
+    true,                                             // DstResetCoord
+    1>;
};
struct SharedMemTrait
@@ -1574,7 +1607,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t raw_n_padded,
const index_t block_idx_n)
{
-ignore = p_d0grad_grid;
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits =
@@ -2109,10 +2141,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
-auto d0_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToBlock(
-    make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
-ignore = d0_thread_copy_vgpr_to_lds;
+// vgpr -> lds copy for the d0 (bias) gradient, scaled element-wise by rp_dropout
+auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+    D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+    make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+    tensor_operation::element_wise::Scale{rp_dropout});
+// lds -> global copy writing the block tile into the d0grad grid
+auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
+    D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+    make_multi_index(0, 0, 0, 0, 0, 0),
+    tensor_operation::element_wise::PassThrough{},
+    d0_grid_desc_m0_n0_m1_m2_n1_m3,
+    make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+    tensor_operation::element_wise::PassThrough{});
+ignore = d0grad_thread_copy_vgpr_to_lds;
if constexpr(Deterministic)
{
block_sync_lds();
@@ -2563,6 +2605,41 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
(undropped_flag ? (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}])
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
+// output bias grad
+if constexpr(!is_same<D0DataType, void>::value)
+{
+    auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+    auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+        D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+
+    static_for<0, D0M0, 1>{}([&](auto mr) {
+        // stage one M0 slice of sgrad in lds (scaled by rp_dropout during the copy)
+        d0grad_thread_copy_vgpr_to_lds.Run(D0Operator::d0_thread_desc_,
+                                           make_tuple(mr, I0, I0, I0, I0),
+                                           sgrad_thread_buf,
+                                           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+                                           d0grad_block_buf);
+        block_sync_lds();
+
+        // write data from lds to global
+        d0_block_copy_lds_to_global.Run(D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+                                        d0grad_block_buf,
+                                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                                        d0grad_grid_buf,
+                                        I0);
+        d0_block_copy_lds_to_global.MoveDstSliceWindow(d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                                                       make_multi_index(0, 0, 1, 0, 0, 0));
+    });
+
+    // move the dst window back for the next outer-loop iteration
+    d0_block_copy_lds_to_global.MoveDstSliceWindow(
+        d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+}
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
......