Unverified commit 5f4e6ec0, authored by Dan Yao, committed by GitHub
Browse files

Merge pull request #1060 from ROCmSoftwarePlatform/mha-train-develop-fix-d-calculate

Fix y and y_grad data vector load size for flash attention
parents 2f93e26f 07f07581
......@@ -852,14 +852,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
Deterministic>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock>;
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
// Argument
struct Argument : public BaseArgument
{
......
......@@ -869,14 +869,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
Deterministic>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock>;
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
// Argument
struct Argument : public BaseArgument
{
......
......@@ -821,14 +821,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock>;
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
......
......@@ -890,14 +890,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
// GridwiseYDotYGrad
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock>;
using GridwiseYDotYGrad = GridwiseBatchedMultiheadAttentionBackward_YDotYGrad<
InputDataType,
DDataType,
DYGridDesc_M_O,
DGridDesc_M,
BlockSize,
DMPerBlock,
DKPerBlock,
Gemm1NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
using DBlock2CTileMap =
OffsettedBlockToCTileMap<typename GridwiseYDotYGrad::DefaultBlock2CTileMap>;
......
......@@ -63,6 +63,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
// Specialization of PassThrough for mixed-precision assignment:
// converts a bhalf_t (bf16) source value to half_t (fp16) via type_convert
// and stores it into y. Added so PassThrough can be used as the element-wise
// op when the destination buffer is fp16 but the source is bf16.
// NOTE(review): rounding behavior is whatever type_convert<half_t>(bhalf_t)
// implements — confirm against the type_convert definition if exactness matters.
template <>
__host__ __device__ void operator()<half_t, bhalf_t>(half_t& y, const bhalf_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
......
......@@ -1250,6 +1250,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
static_assert(SrcScalarPerVector % CShuffleBlockTransferScalarPerVector_NPerBlock == 0, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
......@@ -2007,9 +2008,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
3, // SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
......@@ -2021,9 +2022,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
decltype(ygrad_thread_desc_m_o),
decltype(ygrad_thread_desc_m_o.GetLengths()),
Sequence<0, 1>,
1, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
1, // SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
ygrad_thread_data_on_block_idx);
......
......@@ -1240,6 +1240,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
static_assert(SrcScalarPerVector % CShuffleBlockTransferScalarPerVector_NPerBlock == 0, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatGemmAcc,
......@@ -2102,9 +2103,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
3, // SrcVectorDim
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
......
......@@ -27,7 +27,8 @@ template <typename InputDataType,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t NPadded>
index_t NPadded,
index_t YSrcScalarPerVector>
struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static constexpr auto I0 = Number<0>{};
......@@ -125,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
static_assert(SrcScalarPerVector % YSrcScalarPerVector == 0, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatD,
......@@ -194,19 +196,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
InputDataType,
FloatD,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(y_thread_desc_m0_m1_n0_n1),
decltype(y_thread_desc_m0_m1_n0_n1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_N::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_nblock_nperblock,
y_thread_data_on_grid_idx);
auto yygrad_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<InputDataType,
FloatD,
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
decltype(y_thread_desc_m0_m1_n0_n1),
decltype(y_thread_desc_m0_m1_n0_n1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YSrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
false /* InvalidElementAsNaN */>(
y_grid_desc_mblock_mperblock_nblock_nperblock, y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_N::SrcBufType{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment