Unverified commit 5f4e6ec0 authored by Dan Yao, committed by GitHub

Merge pull request #1060 from ROCmSoftwarePlatform/mha-train-develop-fix-d-calculate

Fix y and y_grad data vector load size for flash attention
parents 2f93e26f 07f07581
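
The hunks below implement the one-line fix described above: the vector width used to load y and y_grad for the D = y·y_grad reduction is now supplied from the device level (CShuffleBlockTransferScalarPerVector_NPerBlock, surfaced in the gridwise kernel as YSrcScalarPerVector) instead of a width derived inside the kernel, and new static_asserts guard the combination. As a rough standalone illustration of why a vectorized load width has to be constrained this way, here is a minimal C++ sketch; the function and parameter names are invented for this example and are not composable_kernel code.

```cpp
// Standalone sketch (illustrative only, not composable_kernel code): a copy whose
// vector width is a compile-time parameter. The guard mirrors the idea behind the
// new static_asserts: the width the caller requests must be compatible with the
// slice each thread actually owns, otherwise the last vector would run past it.
#include <array>
#include <cstddef>
#include <cstdio>

template <typename T, std::size_t ThreadSliceLength, std::size_t ScalarPerVector>
void copy_slice_vectorized(const T* src, T* dst)
{
    static_assert(ThreadSliceLength % ScalarPerVector == 0,
                  "vector load size must evenly divide the per-thread slice");
    for(std::size_t i = 0; i < ThreadSliceLength; i += ScalarPerVector)
    {
        // Inner loop stands in for a single wide (vector) load/store.
        for(std::size_t v = 0; v < ScalarPerVector; ++v)
        {
            dst[i + v] = src[i + v];
        }
    }
}

int main()
{
    std::array<float, 8> y{{0, 1, 2, 3, 4, 5, 6, 7}};
    std::array<float, 8> out{};
    copy_slice_vectorized<float, 8, 4>(y.data(), out.data()); // OK: 8 % 4 == 0
    // copy_slice_vectorized<float, 8, 3>(...) would be rejected at compile time.
    std::printf("%g\n", out[7]);
}
```
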
@@ -852,14 +852,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                                                          Deterministic>;
     // GridwiseYDotYGrad
-    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
-                                                                                  DDataType,
-                                                                                  DYGridDesc_M_O,
-                                                                                  DGridDesc_M,
-                                                                                  BlockSize,
-                                                                                  DMPerBlock,
-                                                                                  DKPerBlock,
-                                                                                  Gemm1NPerBlock>;
+    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
+        InputDataType,
+        DDataType,
+        DYGridDesc_M_O,
+        DGridDesc_M,
+        BlockSize,
+        DMPerBlock,
+        DKPerBlock,
+        Gemm1NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;
 
     // Argument
     struct Argument : public BaseArgument
     {
...
@@ -869,14 +869,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                                                          Deterministic>;
     // GridwiseYDotYGrad
-    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
-                                                                                  DDataType,
-                                                                                  DYGridDesc_M_O,
-                                                                                  DGridDesc_M,
-                                                                                  BlockSize,
-                                                                                  DMPerBlock,
-                                                                                  DKPerBlock,
-                                                                                  Gemm1NPerBlock>;
+    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
+        InputDataType,
+        DDataType,
+        DYGridDesc_M_O,
+        DGridDesc_M,
+        BlockSize,
+        DMPerBlock,
+        DKPerBlock,
+        Gemm1NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;
 
     // Argument
     struct Argument : public BaseArgument
     {
...
@@ -821,14 +821,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
 
     // GridwiseYDotYGrad
-    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
-                                                                                  DDataType,
-                                                                                  DYGridDesc_M_O,
-                                                                                  DGridDesc_M,
-                                                                                  BlockSize,
-                                                                                  DMPerBlock,
-                                                                                  DKPerBlock,
-                                                                                  Gemm1NPerBlock>;
+    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
+        InputDataType,
+        DDataType,
+        DYGridDesc_M_O,
+        DGridDesc_M,
+        BlockSize,
+        DMPerBlock,
+        DKPerBlock,
+        Gemm1NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;
 
     using DBlock2CTileMap =
         OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...
@@ -890,14 +890,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
 
     // GridwiseYDotYGrad
-    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
-                                                                                  DDataType,
-                                                                                  DYGridDesc_M_O,
-                                                                                  DGridDesc_M,
-                                                                                  BlockSize,
-                                                                                  DMPerBlock,
-                                                                                  DKPerBlock,
-                                                                                  Gemm1NPerBlock>;
+    using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
+        InputDataType,
+        DDataType,
+        DYGridDesc_M_O,
+        DGridDesc_M,
+        BlockSize,
+        DMPerBlock,
+        DKPerBlock,
+        Gemm1NPerBlock,
+        CShuffleBlockTransferScalarPerVector_NPerBlock>;
 
     using DBlock2CTileMap =
         OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
...
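
All four device-layer hunks above make the same change: the existing CShuffleBlockTransferScalarPerVector_NPerBlock tuning parameter is forwarded as an additional template argument to GridwiseBatchedMultiheadAttentionBackward_YDotYGrad. A compressed sketch of that forwarding pattern, with invented names and far fewer parameters than the real templates, might look like this:

```cpp
// Sketch of forwarding a compile-time tuning knob from a device-level template
// into the gridwise kernel it instantiates. Names and parameters are invented
// for illustration; the real class templates take many more arguments.
#include <cstdio>

template <int BlockSize, int MPerBlock, int YSrcScalarPerVector>
struct GridwiseYDotYGradSketch
{
    // With the width passed in explicitly, the gridwise kernel can check it
    // against its own tiling at compile time (a simplified stand-in for the
    // patch's SrcScalarPerVector % YSrcScalarPerVector check).
    static_assert(MPerBlock % YSrcScalarPerVector == 0, "");
    static constexpr int vector_width = YSrcScalarPerVector;
};

template <int BlockSize, int MPerBlock, int CShuffleScalarPerVector>
struct DeviceMhaBackwardSketch
{
    // The device layer already owns this knob; the patch forwards it rather than
    // leaving the gridwise kernel to pick a load width on its own.
    using GridwiseYDotYGrad =
        GridwiseYDotYGradSketch<BlockSize, MPerBlock, CShuffleScalarPerVector>;
};

int main()
{
    using Device = DeviceMhaBackwardSketch<256, 128, 8>;
    std::printf("%d\n", Device::GridwiseYDotYGrad::vector_width); // prints 8
}
```
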
@@ -63,6 +63,12 @@ struct PassThrough
         y = type_convert<bhalf_t>(x);
     }
 
+    template <>
+    __host__ __device__ void operator()<half_t, bhalf_t>(half_t& y, const bhalf_t& x) const
+    {
+        y = type_convert<half_t>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
...
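
The PassThrough hunk adds one more explicit specialization of the element-wise pass-through functor, converting a bhalf_t source into a half_t destination via type_convert. A standalone sketch of the same member-template specialization pattern, with ordinary host types standing in for ck::half_t and ck::bhalf_t so it compiles without the library headers, might look like this:

```cpp
// Sketch of the PassThrough specialization pattern using ordinary types in place
// of ck::half_t / ck::bhalf_t. Only the shape of the code matches the library;
// the types and the conversion are stand-ins.
#include <cstdio>

struct PassThroughSketch
{
    // Primary member template: one specialization per (destination, source) pair.
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const;
};

// Same-type case: a straight copy.
template <>
void PassThroughSketch::operator()<float, float>(float& y, const float& x) const
{
    y = x;
}

// Mixed-type case, analogous to the new half_t <- bhalf_t specialization:
// convert the source representation on assignment (type_convert in the library).
template <>
void PassThroughSketch::operator()<float, double>(float& y, const double& x) const
{
    y = static_cast<float>(x);
}

int main()
{
    PassThroughSketch op;
    float y = 0.0f;
    op(y, 2.5); // picks the <float, double> specialization
    std::printf("%g\n", y);
}
```
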
@@ -1250,6 +1250,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
+        static_assert(SrcScalarPerVector % CShuffleBlockTransferScalarPerVector_NPerBlock == 0, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         FloatGemmAcc,
@@ -2007,9 +2008,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 decltype(y_thread_desc_m0_m1_o0_o1),
                 decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
                 Sequence<0, 1, 2, 3>,
-                3,                                 // SrcVectorDim
-                YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
-                1,                                 // SrcScalarStrideInVector
+                3,                                              // SrcVectorDim
+                CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
+                1,                                              // SrcScalarStrideInVector
                 true /* ResetCoordAfterRun */>(y_grid_desc_mblock_mperblock_oblock_operblock,
                                                y_thread_data_on_grid_idx);
@@ -2021,9 +2022,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 decltype(ygrad_thread_desc_m_o),
                 decltype(ygrad_thread_desc_m_o.GetLengths()),
                 Sequence<0, 1>,
-                1,                                 // SrcVectorDim
-                YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
-                1,                                 // SrcScalarStrideInVector
+                1,                                              // SrcVectorDim
+                CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
+                1,                                              // SrcScalarStrideInVector
                 true /* ResetCoordAfterRun */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
                                                ygrad_thread_data_on_block_idx);
...
@@ -1240,6 +1240,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
+        static_assert(SrcScalarPerVector % CShuffleBlockTransferScalarPerVector_NPerBlock == 0, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         FloatGemmAcc,
@@ -2102,9 +2103,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 decltype(y_thread_desc_m0_m1_o0_o1),
                 decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
                 Sequence<0, 1, 2, 3>,
-                3,                                 // SrcVectorDim
-                YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
-                1,                                 // SrcScalarStrideInVector
+                3,                                              // SrcVectorDim
+                CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
+                1,                                              // SrcScalarStrideInVector
                 true /* ResetCoordAfterRun */,
                 false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
                                                  y_thread_data_on_grid_idx);
...
@@ -27,7 +27,8 @@ template <typename InputDataType,
           index_t BlockSize,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t NPadded>
+          index_t NPadded,
+          index_t YSrcScalarPerVector>
 struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
 {
     static constexpr auto I0 = Number<0>{};
@@ -125,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
         static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
         static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
+        static_assert(SrcScalarPerVector % YSrcScalarPerVector == 0, "");
 
         using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         FloatD,
@@ -194,19 +196,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
             y_thread_data_on_block_idx;
 
         // performs double duty for both y and ygrad
-        auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
-            InputDataType,
-            FloatD,
-            YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-            decltype(y_thread_desc_m0_m1_n0_n1),
-            decltype(y_thread_desc_m0_m1_n0_n1.GetLengths()),
-            Sequence<0, 1, 2, 3>,
-            3,                                 // SrcVectorDim
-            YDotYGrad_M_N::SrcScalarPerVector, // SrcScalarPerVector
-            1,                                 // SrcScalarStrideInVector
-            true /* ResetCoordAfterRun */,
-            false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock,
-                                             y_thread_data_on_grid_idx);
+        auto yygrad_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<InputDataType,
+                                             FloatD,
+                                             YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                                             decltype(y_thread_desc_m0_m1_n0_n1),
+                                             decltype(y_thread_desc_m0_m1_n0_n1.GetLengths()),
+                                             Sequence<0, 1, 2, 3>,
+                                             3,                   // SrcVectorDim
+                                             YSrcScalarPerVector, // SrcScalarPerVector
+                                             1,                   // SrcScalarStrideInVector
+                                             true /* ResetCoordAfterRun */,
+                                             false /* InvalidElementAsNaN */>(
+                y_grid_desc_mblock_mperblock_nblock_nperblock, y_thread_data_on_grid_idx);
 
         auto y_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
         auto ygrad_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
...