Commit 7ae26b79 authored by wangshaojie6's avatar wangshaojie6
Browse files

rename template and remove default template value

parent 1dc91af9
......@@ -118,7 +118,7 @@ using DeviceGemmInstance =
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // OnlyLowerTriangle
true>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemmUpperTriangleMinusInf<ADataType,
......
......@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
false>; // MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
......@@ -168,7 +168,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool OnlyLowerTriangle = false,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
......@@ -500,7 +500,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
OnlyLowerTriangle>;
MaskOutUpperTriangle>;
// Argument
// FIXME: constness
......
......@@ -77,7 +77,7 @@ template <typename FloatAB,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool OnlyLowerTriangle = false>
bool MaskOutUpperTriangle>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
......@@ -767,7 +767,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t gemm1_k_block_outer_index = 0;
do
{
if constexpr(OnlyLowerTriangle)
if constexpr(MaskOutUpperTriangle)
{
auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
......@@ -792,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf,
num_k_block_main_loop);
if constexpr(!OnlyLowerTriangle)
if constexpr(!MaskOutUpperTriangle)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment