Commit 5d6bfabb authored by letaoqin's avatar letaoqin
Browse files

add d vector load template parameters

parent 8e7b98eb
...@@ -136,6 +136,7 @@ using DeviceGemmInstance = ...@@ -136,6 +136,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
1,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -147,6 +148,7 @@ using DeviceGemmInstance = ...@@ -147,6 +148,7 @@ using DeviceGemmInstance =
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
1,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
...@@ -207,6 +209,7 @@ using DeviceGemmInstance = ...@@ -207,6 +209,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
1,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -218,6 +221,7 @@ using DeviceGemmInstance = ...@@ -218,6 +221,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
1,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
...@@ -278,6 +282,7 @@ using DeviceGemmInstance = ...@@ -278,6 +282,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
1,
S<8, 32, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
...@@ -289,6 +294,7 @@ using DeviceGemmInstance = ...@@ -289,6 +294,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
1,
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#endif #endif
......
...@@ -256,6 +256,7 @@ template <index_t NumDimG, ...@@ -256,6 +256,7 @@ template <index_t NumDimG,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, bool BBlockLdsExtraN,
index_t Acc0BiasTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -267,6 +268,7 @@ template <index_t NumDimG, ...@@ -267,6 +268,7 @@ template <index_t NumDimG,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t Acc1BiasTransferSrcScalarPerVector,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
...@@ -561,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -561,6 +563,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
true, true,
Acc0BiasTransferSrcScalarPerVector,
BBlockLdsExtraN, BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
...@@ -574,6 +577,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -574,6 +577,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
Acc1BiasTransferSrcScalarPerVector,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec != MaskingSpecialization::MaskDisabled, MaskingSpec != MaskingSpecialization::MaskDisabled,
......
...@@ -74,6 +74,7 @@ template <typename FloatAB, ...@@ -74,6 +74,7 @@ template <typename FloatAB,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN, index_t BBlockLdsExtraN,
index_t D0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder, typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder, typename B1BlockTransferSrcAccessOrder,
...@@ -86,6 +87,7 @@ template <typename FloatAB, ...@@ -86,6 +87,7 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t D1BlockTransferSrcScalarPerVector,
LoopScheduler LoopSched, LoopScheduler LoopSched,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
...@@ -930,7 +932,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -930,7 +932,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
n4>, n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, 9,
4, D0BlockTransferSrcScalarPerVector,
1, 1,
false>(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, false>(d_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId make_multi_index(block_work_idx[I0], // MBlockId
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment