Commit d173a2cb authored by letaoqin

batched gemm: add D0 vector load

parent 9679ba63
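
This change threads a vector-load width for the attention bias (D0) tensors through the example kernel instances and the device op, and rejects arguments whose raw bias extent cannot be covered by whole vectors. As a conceptual sketch only (plain C++ with made-up names, not the CK threadwise copy), a vectorized copy moves ScalarPerVector contiguous elements per step, which is why the raw length of the contiguous dimension must be a multiple of that width:

#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

template <int ScalarPerVector>
void copy_row_vectorized(const float* src, float* dst, int raw_n)
{
    // Same divisibility condition the device op now enforces: if raw_n is
    // not a multiple of the vector width, the last vector would read past
    // the end of the row.
    assert(raw_n % ScalarPerVector == 0);
    for(int i = 0; i < raw_n; i += ScalarPerVector)
    {
        std::array<float, ScalarPerVector> v{}; // stands in for one packed vector load
        for(int j = 0; j < ScalarPerVector; ++j)
            v[j] = src[i + j];
        for(int j = 0; j < ScalarPerVector; ++j)
            dst[i + j] = v[j];
    }
}

int main()
{
    std::vector<float> src(128, 1.0f), dst(128, 0.0f);
    copy_row_vectorized<4>(src.data(), dst.data(), 128); // 128 % 4 == 0, allowed
    std::printf("dst[127] = %.1f\n", dst[127]);
    return 0;
}

With ScalarPerVector = 4 (the value used in the instances below), a row length of 130 would leave a 2-element tail that a packed load cannot cover, so such an argument has to be rejected.
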
@@ -136,6 +136,7 @@ using DeviceGemmInstance =
8,
8,
true,
4, // Acc0BiasTransferSrcScalarPerVector
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
@@ -147,6 +148,7 @@ using DeviceGemmInstance =
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // Acc1BiasTransferSrcScalarPerVector
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
@@ -207,6 +209,7 @@ using DeviceGemmInstance =
8,
8,
true,
4, // Acc0BiasTransferSrcScalarPerVector
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
@@ -218,6 +221,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // Acc1BiasTransferSrcScalarPerVector
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 128)
@@ -278,6 +282,7 @@ using DeviceGemmInstance =
8,
8,
true,
4, // Acc0BiasTransferSrcScalarPerVector
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
@@ -289,6 +294,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // Acc1BiasTransferSrcScalarPerVector
MaskingSpec, // MaskingSpecialization
Deterministic>;
#endif
......
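
The header diff that follows declares the two new template parameters, Acc0BiasTransferSrcScalarPerVector and Acc1BiasTransferSrcScalarPerVector, and forwards them to the inner gridwise GEMM type. A minimal, compilable sketch of that forwarding pattern, using hypothetical stand-in types (DeviceOpSketch, GridwiseKernelSketch) rather than the real CK classes:

#include <cstdio>

using index_t = int; // CK's index_t is an integer type; plain int suffices for the sketch

// Inner "gridwise" type that consumes the vector widths.
template <index_t Acc0BiasVec, index_t Acc1BiasVec>
struct GridwiseKernelSketch
{
    static constexpr index_t kD0Vec = Acc0BiasVec; // vector width for D0 (acc0 bias) loads
    static constexpr index_t kD1Vec = Acc1BiasVec; // vector width for D1 (acc1 bias) loads
};

// Device-level template: simply forwards the new knobs downward, analogous
// to what the device op does for its gridwise GEMM in the diff below.
template <index_t Acc0BiasTransferSrcScalarPerVector,
          index_t Acc1BiasTransferSrcScalarPerVector>
struct DeviceOpSketch
{
    using Gridwise = GridwiseKernelSketch<Acc0BiasTransferSrcScalarPerVector,
                                          Acc1BiasTransferSrcScalarPerVector>;
};

int main()
{
    using Op = DeviceOpSketch<4, 4>; // 4 matches the value used in the instances above
    std::printf("D0 vector width = %d\n", static_cast<int>(Op::Gridwise::kD0Vec));
    return 0;
}
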
@@ -274,6 +274,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t Acc0BiasTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
@@ -285,6 +286,7 @@ template <index_t NumDimG,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t Acc1BiasTransferSrcScalarPerVector,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
@@ -347,6 +349,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BSpec,
B1Spec,
CSpec>;
// Same transform as above but with GemmSpecialization::Default (no padding), so raw lengths can be read back
using RawTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::Default,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
@@ -552,6 +562,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
@@ -564,6 +575,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Acc1BiasTransferSrcScalarPerVector,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec != MaskingSpecialization::MaskDisabled,
@@ -670,7 +682,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_g_m_n_,
d_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
raw_d0_n_(0)
{
// TODO ANT: implement bias addition
ignore = p_acc1_biases;
@@ -709,6 +722,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{
is_lse_storing_ = false;
}
if constexpr(NumD0Tensor)
{
const auto d0_grid_desc_m_n = RawTransform::MakeCGridDescriptor_M_N(
acc0_biases_gs_ms_ns_lengths[0], acc0_biases_gs_ms_ns_strides[0]);
raw_d0_n_ = d0_grid_desc_m_n.GetLength(I1);
}
}
void Print() const
@@ -794,6 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
index_t m_raw_padded_;
index_t n_raw_padded_;
// raw (unpadded) D0 length along N, needed for the vector-load check
int raw_d0_n_;
};
// Invoker
@@ -1000,6 +1022,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return false;
}
if(arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
......
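
For reference, a standalone sketch of the new supportability rule added above; the fixed value 4 here mirrors the example instances, while in the device op the width is the Acc0BiasTransferSrcScalarPerVector template parameter and the length comes from arg.raw_d0_n_:

#include <cstdio>

// Vector width for the D0 (acc0 bias) load, mirroring the "4" used above.
constexpr int Acc0BiasTransferSrcScalarPerVector = 4;

bool is_supported_d0_n(int raw_d0_n)
{
    // Same rule as the new guard in IsSupportedArgument: reject arguments
    // whose raw N extent cannot be covered by whole vectors.
    return raw_d0_n % Acc0BiasTransferSrcScalarPerVector == 0;
}

int main()
{
    std::printf("raw_d0_n = 512 -> %s\n", is_supported_d0_n(512) ? "supported" : "rejected");
    std::printf("raw_d0_n = 130 -> %s\n", is_supported_d0_n(130) ? "supported" : "rejected");
    return 0;
}
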