Commit 1dc91af9 authored by wangshaojie6's avatar wangshaojie6
Browse files

add template to distinguish masking kernel

parent 7b18e6fd
...@@ -117,7 +117,8 @@ using DeviceGemmInstance = ...@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // OnlyLowerTriangle
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType,
......
...@@ -168,6 +168,7 @@ template <typename ALayout, ...@@ -168,6 +168,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool OnlyLowerTriangle = false,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
...@@ -498,7 +499,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -498,7 +499,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN>; matrix_padder.PadN,
OnlyLowerTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
......
...@@ -76,7 +76,8 @@ template <typename FloatAB, ...@@ -76,7 +76,8 @@ template <typename FloatAB,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
bool PadN> bool PadN,
bool OnlyLowerTriangle = false>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
...@@ -756,8 +757,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -756,8 +757,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// decoder lower triangular mask // decoder lower triangular mask
const auto thread_cluster_idx = const auto thread_cluster_idx =
threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[Number<1>{}]; const auto thread_n_cluster_id = thread_cluster_idx[I1];
const index_t MPerRepeat = MPerBlock / MXdlPerWave; const index_t MPerRepeat = MPerBlock / MXdlPerWave;
const index_t NPerRepeat = NPerBlock / NXdlPerWave; const index_t NPerRepeat = NPerBlock / NXdlPerWave;
const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id; const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
...@@ -766,9 +767,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -766,9 +767,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
do do
{ {
if((m_block_data_idx_on_grid < gemm1_k_block_outer_index * NPerBlock) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm1_k_block_outer_index * NPerBlock + NPerBlock - 1))) if constexpr(OnlyLowerTriangle)
{ {
continue; auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
{
continue;
}
} }
// gemm0 // gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
...@@ -787,6 +792,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -787,6 +792,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
if constexpr(!OnlyLowerTriangle)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
}
else
{
const index_t nstart = gemm1_k_block_outer_index * NPerBlock; const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
static_for<0, m0, 1>{}([&](auto m0_i) { static_for<0, m0, 1>{}([&](auto m0_i) {
...@@ -821,6 +841,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -821,6 +841,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}); });
}); });
}); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment