Commit c88d1173 authored by letaoqin's avatar letaoqin
Browse files

change d0 operator variables name

parent 12c0f86a
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -91,7 +91,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -91,7 +91,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1257,14 +1257,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1257,14 +1257,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1276,18 +1281,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1276,18 +1281,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1295,21 +1300,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1295,21 +1300,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1>; 1>;
using D0ThreadwiseCopyLdsToVgpr = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>;
using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
...@@ -1319,7 +1324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1319,7 +1324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1< using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1330,18 +1335,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1330,18 +1335,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim 5, // SrcVectorDim
4, // DstVectorDim 4, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1381,8 +1386,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1381,8 +1386,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
q_block_space_size_aligned.value; q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value + (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) * q_block_space_size_aligned.value) *
...@@ -1898,23 +1902,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1898,23 +1902,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds( auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
...@@ -2062,7 +2067,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2062,7 +2067,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
...@@ -2076,16 +2081,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2076,16 +2081,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run( d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_block_buf,
d0_block_buf, D0Operator::d0_thread_desc_,
D0Operator::d0_thread_desc_, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_thread_buf);
d0_thread_buf);
// bias add // bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
...@@ -2197,36 +2201,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2197,36 +2201,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf, sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
// write data from lds to global // write data from lds to global
d0_block_copy_lds_to_global.Run( d0grad_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf, d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf, d0grad_grid_buf,
I0); I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
}); });
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
} }
......
...@@ -1336,14 +1336,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1336,14 +1336,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1355,18 +1360,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1355,18 +1360,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1374,21 +1379,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1374,21 +1379,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1>; 1>;
using D0ThreadwiseCopyLdsToVgpr = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>;
using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
...@@ -1398,7 +1403,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1398,7 +1403,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1< using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1409,18 +1414,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1409,18 +1414,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim 5, // SrcVectorDim
4, // DstVectorDim 4, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1460,8 +1465,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1460,8 +1465,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
math::integer_least_multiple(BlockSize, max_lds_align); math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) / k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Operator::template TypeTransform<D0DataType>::Size; D0Operator::template TypeTransform<D0DataType>::Size;
...@@ -2019,23 +2023,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2019,23 +2023,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds( auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
...@@ -2213,7 +2218,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2213,7 +2218,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
...@@ -2227,16 +2232,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2227,16 +2232,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run( d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_block_buf,
d0_block_buf, D0Operator::d0_thread_desc_,
D0Operator::d0_thread_desc_, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_thread_buf);
d0_thread_buf);
// bias add // bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
...@@ -2464,36 +2468,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -2464,36 +2468,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf, sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
// write data from lds to global // write data from lds to global
d0_block_copy_lds_to_global.Run( d0grad_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf, d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf, d0grad_grid_buf,
I0); I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
}); });
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
} }
......
...@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1325,14 +1325,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1325,14 +1325,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1344,18 +1349,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1344,18 +1349,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1363,21 +1368,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1363,21 +1368,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1>; 1>;
using D0ThreadwiseCopyLdsToVgpr = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>;
using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
...@@ -1387,7 +1392,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1387,7 +1392,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1< using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1398,18 +1403,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1398,18 +1403,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim 5, // SrcVectorDim
4, // DstVectorDim 4, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1458,8 +1463,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1458,8 +1463,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
sizeof(GemmDataType) / sizeof(FloatGemmAcc); sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value + (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) * q_block_space_size_aligned.value) *
...@@ -2060,23 +2064,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2060,23 +2064,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds( auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
...@@ -2263,7 +2268,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2263,7 +2268,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
...@@ -2277,16 +2282,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2277,16 +2282,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run( d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_block_buf,
d0_block_buf, D0Operator::d0_thread_desc_,
D0Operator::d0_thread_desc_, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_thread_buf);
d0_thread_buf);
// bias add // bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
...@@ -2398,36 +2402,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2398,36 +2402,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf, sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
// write data from lds to global // write data from lds to global
d0_block_copy_lds_to_global.Run( d0grad_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf, d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf, d0grad_grid_buf,
I0); I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
}); });
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
} }
......
...@@ -1381,14 +1381,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1381,14 +1381,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1400,18 +1405,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1400,18 +1405,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1419,21 +1424,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1419,21 +1424,21 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1>; 1>;
using D0ThreadwiseCopyLdsToVgpr = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>;
using D0ThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3< using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
typename TypeTransform<D0DataType>::Type, typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_), decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder Sequence<0, 1, 2, 3, 4>, // AccessOrder
...@@ -1443,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1443,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>; true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1< using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1454,18 +1459,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1454,18 +1459,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim 5, // SrcVectorDim
4, // DstVectorDim 4, // DstVectorDim
4, // SrcScalarPerVector 4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1, 1,
1, 1,
true, true,
...@@ -1512,8 +1517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1512,8 +1517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
sizeof(GemmDataType) / sizeof(FloatGemmAcc); sizeof(GemmDataType) / sizeof(FloatGemmAcc);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) / k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Operator::template TypeTransform<D0DataType>::Size; D0Operator::template TypeTransform<D0DataType>::Size;
...@@ -2132,23 +2136,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2132,23 +2136,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadwiseCopyVgprToLds( auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0), make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout}); tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal( auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
...@@ -2365,7 +2370,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2365,7 +2370,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
...@@ -2379,16 +2384,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2379,16 +2384,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run( d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_block_buf,
d0_block_buf, D0Operator::d0_thread_desc_,
D0Operator::d0_thread_desc_, make_tuple(I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0), d0_thread_buf);
d0_thread_buf);
// bias add // bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
...@@ -2616,36 +2620,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2616,36 +2620,36 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if(p_d0grad_grid != nullptr) if(p_d0grad_grid != nullptr)
{ {
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run( d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0), make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf, sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf); d0grad_block_buf);
block_sync_lds(); block_sync_lds();
// write data from lds to global // write data from lds to global
d0_block_copy_lds_to_global.Run( d0grad_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf, d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf, d0grad_grid_buf,
I0); I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
}); });
d0_block_copy_lds_to_global.MoveDstSliceWindow( d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment