Commit 5cc0fd88 authored by letaoqin
Browse files

D0 data loading uses D0BlockTransferSrcScalarPerVector

parent e938bd61
......@@ -101,7 +101,7 @@ using DeviceGemmInstance =
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
......@@ -121,13 +121,13 @@ using DeviceGemmInstance =
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
8,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
DIM / 32,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
......
......@@ -202,7 +202,6 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
int D0sTransferSrcScalarPerVector = 4,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionInfer<NumDimG,
......
......@@ -88,11 +88,6 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -392,20 +387,18 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
struct D0Operator
{
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector ==
0);
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
static constexpr index_t Size0 = sizeof(DataType);
static constexpr index_t Size = sizeof(DataType);
using Type = DataType;
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
using Type = ck::half_t;
};
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
......@@ -463,7 +456,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
Sequence<0, 1, 2, 4, 3, 5>,
5,
5,
ABlockTransferSrcScalarPerVector,
D0BlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment