Commit 5938d555 authored by ltqin's avatar ltqin
Browse files

add DDataType and DKPerBlock parameters to device

parent 7c686fc2
......@@ -284,6 +284,7 @@ template <index_t NumDimG,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
......@@ -314,6 +315,7 @@ template <index_t NumDimG,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
index_t DKPerBlock,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -348,7 +350,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// NOTE(review): both operands test NumAcc0Bias — the second comparison likely
// intended NumAcc1Bias, so Acc1 bias currently goes unchecked; confirm upstream.
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -764,7 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
DGridDesc_M,
BlockSize,
BlockSize,
32>;
DKPerBlock>;
// Argument
struct Argument : public BaseArgument
{
......@@ -1161,7 +1162,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
// TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
......
......@@ -283,6 +283,7 @@ template <index_t NumDimG,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename DDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
......@@ -313,6 +314,7 @@ template <index_t NumDimG,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
index_t DKPerBlock,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -354,7 +356,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// NOTE(review): both operands test NumAcc0Bias — the second comparison likely
// intended NumAcc1Bias, so Acc1 bias currently goes unchecked; confirm upstream.
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DDataType = GemmAccDataType;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -778,7 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
DGridDesc_M,
BlockSize,
BlockSize,
64>;
DKPerBlock>;
// Argument
struct Argument : public BaseArgument
{
......@@ -1188,7 +1189,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
// TODO: Check if tensor specialization & strides mismatch
if(!GridwiseYDotYGrad::CheckValidity(arg.y_grid_desc_m_o_, arg.d_block_2_ctile_map_))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
......
......@@ -36,6 +36,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static constexpr auto I4 = Number<4>{};
static constexpr auto WaveSize = 64;
static_assert(BlockSize == MPerBlock, "BlockSize must be same with MPerBlock");
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
......@@ -46,7 +47,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
return false;
}
const auto M = c_grid_desc_m_n.GetLength(I0);
if(M < MPerBlock)
{
return false;
}
if(M % MPerBlock != 0)
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment