Commit d173a2cb authored by letaoqin
Browse files

batched gemm add d0 vector load

parent 9679ba63
...@@ -136,6 +136,7 @@ using DeviceGemmInstance = ...@@ -136,6 +136,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
4,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -147,6 +148,7 @@ using DeviceGemmInstance = ...@@ -147,6 +148,7 @@ using DeviceGemmInstance =
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
...@@ -207,6 +209,7 @@ using DeviceGemmInstance = ...@@ -207,6 +209,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
4,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -218,6 +221,7 @@ using DeviceGemmInstance = ...@@ -218,6 +221,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
...@@ -278,6 +282,7 @@ using DeviceGemmInstance = ...@@ -278,6 +282,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
4,
S<8, 32, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -289,6 +294,7 @@ using DeviceGemmInstance = ...@@ -289,6 +294,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
4,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#endif #endif
......
...@@ -274,6 +274,7 @@ template <index_t NumDimG, ...@@ -274,6 +274,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t Acc0BiasTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -285,6 +286,7 @@ template <index_t NumDimG, ...@@ -285,6 +286,7 @@ template <index_t NumDimG,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t Acc1BiasTransferSrcScalarPerVector,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
...@@ -347,6 +349,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -347,6 +349,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BSpec, BSpec,
B1Spec, B1Spec,
CSpec>; CSpec>;
using RawTransform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpecialization::Default,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::vector<index_t>& a_gs_ms_ks_strides_vec)
...@@ -552,6 +562,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -552,6 +562,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
BBlockLdsExtraN, BBlockLdsExtraN,
Acc0BiasTransferSrcScalarPerVector,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -564,6 +575,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -564,6 +575,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
Acc1BiasTransferSrcScalarPerVector,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec != MaskingSpecialization::MaskDisabled, MaskingSpec != MaskingSpecialization::MaskDisabled,
...@@ -670,7 +682,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -670,7 +682,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
d_grid_desc_g_m_n_, d_grid_desc_g_m_n_,
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
raw_d0_n_(0)
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc1_biases; ignore = p_acc1_biases;
...@@ -709,6 +722,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -709,6 +722,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
{ {
is_lse_storing_ = false; is_lse_storing_ = false;
} }
if constexpr(NumD0Tensor)
{
const auto d0_grid_desc_m_n = RawTransform::MakeCGridDescriptor_M_N(
acc0_biases_gs_ms_ns_lengths[0], acc0_biases_gs_ms_ns_strides[0]);
raw_d0_n_ = d0_grid_desc_m_n.GetLength(I1);
}
} }
void Print() const void Print() const
...@@ -794,6 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -794,6 +813,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
index_t m_raw_padded_; index_t m_raw_padded_;
index_t n_raw_padded_; index_t n_raw_padded_;
// raw data
int raw_d0_n_;
}; };
// Invoker // Invoker
...@@ -1000,6 +1022,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1000,6 +1022,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
return false; return false;
} }
if(arg.raw_d0_n_ % Acc0BiasTransferSrcScalarPerVector != 0)
{
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of // Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds // vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment