Commit 7ae26b79 authored by wangshaojie6's avatar wangshaojie6
Browse files

rename template and remove default template value

parent 1dc91af9
...@@ -118,7 +118,7 @@ using DeviceGemmInstance = ...@@ -118,7 +118,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // OnlyLowerTriangle true>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType,
......
...@@ -117,7 +117,8 @@ using DeviceGemmInstance = ...@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
false>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
...@@ -168,7 +168,7 @@ template <typename ALayout, ...@@ -168,7 +168,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool OnlyLowerTriangle = false, bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
...@@ -500,7 +500,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -500,7 +500,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN, matrix_padder.PadN,
OnlyLowerTriangle>; MaskOutUpperTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
......
...@@ -77,7 +77,7 @@ template <typename FloatAB, ...@@ -77,7 +77,7 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched, LoopScheduler LoopSched,
bool PadN, bool PadN,
bool OnlyLowerTriangle = false> bool MaskOutUpperTriangle>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
...@@ -767,7 +767,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -767,7 +767,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
do do
{ {
if constexpr(OnlyLowerTriangle) if constexpr(MaskOutUpperTriangle)
{ {
auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1))) if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
...@@ -792,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -792,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf, acc_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
if constexpr(!OnlyLowerTriangle) if constexpr(!MaskOutUpperTriangle)
{ {
// Acc0 elementwise Op // Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER #if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment