Commit 1dc91af9 authored by wangshaojie6's avatar wangshaojie6
Browse files

add template to distinguish masking kernel

parent 7b18e6fd
......@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // OnlyLowerTriangle
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType,
......
......@@ -168,6 +168,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool OnlyLowerTriangle = false,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
......@@ -498,7 +499,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN>;
matrix_padder.PadN,
OnlyLowerTriangle>;
// Argument
// FIXME: constness
......
......@@ -76,7 +76,8 @@ template <typename FloatAB,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN>
bool PadN,
bool OnlyLowerTriangle = false>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -756,8 +757,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// decoder lower triangular mask
const auto thread_cluster_idx =
threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_n_cluster_id = thread_cluster_idx[Number<1>{}];
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
const index_t MPerRepeat = MPerBlock / MXdlPerWave;
const index_t NPerRepeat = NPerBlock / NXdlPerWave;
const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
......@@ -766,9 +767,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t gemm1_k_block_outer_index = 0;
do
{
if((m_block_data_idx_on_grid < gemm1_k_block_outer_index * NPerBlock) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm1_k_block_outer_index * NPerBlock + NPerBlock - 1)))
if constexpr(OnlyLowerTriangle)
{
continue;
auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
{
continue;
}
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
......@@ -787,6 +792,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf,
num_k_block_main_loop);
if constexpr(!OnlyLowerTriangle)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
}
else
{
const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
static_for<0, m0, 1>{}([&](auto m0_i) {
......@@ -821,6 +841,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
});
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment