Commit 5938d555 authored by ltqin's avatar ltqin
Browse files

Add DDataType and DKPerBlock parameters to device

parent 7c686fc2
...@@ -284,6 +284,7 @@ template <index_t NumDimG, ...@@ -284,6 +284,7 @@ template <index_t NumDimG,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -314,6 +315,7 @@ template <index_t NumDimG, ...@@ -314,6 +315,7 @@ template <index_t NumDimG,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave, index_t Gemm2NXdlPerWave,
index_t DKPerBlock,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -347,8 +349,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -347,8 +349,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -764,7 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -764,7 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, BlockSize,
32>; DKPerBlock>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -1161,7 +1162,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1161,7 +1162,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
......
...@@ -283,6 +283,7 @@ template <index_t NumDimG, ...@@ -283,6 +283,7 @@ template <index_t NumDimG,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -313,6 +314,7 @@ template <index_t NumDimG, ...@@ -313,6 +314,7 @@ template <index_t NumDimG,
index_t NXdlPerWave, index_t NXdlPerWave,
index_t Gemm1NXdlPerWave, index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave, index_t Gemm2NXdlPerWave,
index_t DKPerBlock,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -353,8 +355,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -353,8 +355,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -778,7 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -778,7 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
DGridDesc_M, DGridDesc_M,
BlockSize, BlockSize,
BlockSize, BlockSize,
64>; DKPerBlock>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -1188,7 +1189,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1188,7 +1189,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
// TODO: Check if tensor specialization & strides mismatch // TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
......
...@@ -36,6 +36,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -36,6 +36,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
static_assert(BlockSize == MPerBlock, "BlockSize must be same with MPerBlock");
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> template <typename Block2CTileMap>
...@@ -46,7 +47,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad ...@@ -46,7 +47,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{ {
return false; return false;
} }
const auto M = c_grid_desc_m_n.GetLength(I0);
if(M < MPerBlock)
{
return false;
}
if(M % MPerBlock != 0)
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment