Commit 321b6c8e authored by ltqin's avatar ltqin
Browse files

change enum to MaskUpperTringleFrom

parent b4514459
......@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTringleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{
return MaskUpperTringleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -628,7 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -828,7 +832,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
......@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTringleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{
return MaskUpperTringleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -636,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -836,7 +840,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
......@@ -401,9 +401,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTringleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{
return MaskUpperTringleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -531,7 +535,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -697,7 +701,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
......@@ -407,9 +407,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
return MaskOutUpperTrianglePredicate{};
return MaskUpperTringleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{
return MaskUpperTringleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......@@ -537,7 +541,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
......@@ -708,7 +712,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp;
......
......@@ -10,7 +10,7 @@ namespace device {
enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle,
MaskUpperTringleFromTopLeft,
MaskUpperTringleFromBottomRight
};
......@@ -19,7 +19,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
switch(s)
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
case MaskingSpecialization::MaskUpperTringleFromTopLeft: return "MaskUpperTringleFromTopLeft";
case MaskingSpecialization::MaskUpperTringleFromBottomRight:
return "MaskUpperTringleFromBottomRight";
default: return "Unrecognized specialization!";
......@@ -40,7 +40,7 @@ struct MaskDisabledPredicate
}
};
struct MaskOutUpperTrianglePredicate
struct MaskUpperTringleFromTopLeftPredicate
{
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
......
......@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
......@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
......
......@@ -57,7 +57,7 @@ template <typename ALayout,
typename B0DataType,
typename B1DataType,
typename CDataType,
bool MaskOutUpperTriangle>
bool MaskUpperTringleFromTopLeft>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
......@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
Scale,
PassThrough,
PassThrough,
MaskOutUpperTriangle>>
MaskUpperTringleFromTopLeft>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
......@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
Scale,
PassThrough,
PassThrough,
MaskOutUpperTriangle>;
MaskUpperTringleFromTopLeft>;
static auto GetInstances()
{
......@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
if constexpr(MaskOutUpperTriangle)
if constexpr(MaskUpperTringleFromTopLeft)
{
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
......
......@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
......@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
......
......@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances)
{
add_device_operation_instances(
......@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
MaskingSpecialization::MaskUpperTringleFromTopLeft>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
......@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances)
{
add_device_operation_instances(
......@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
MaskingSpecialization::MaskUpperTringleFromTopLeft>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
......@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances)
{
add_device_operation_instances(
......@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
MaskingSpecialization::MaskUpperTringleFromTopLeft>{});
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
......@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>&
instances)
{
add_device_operation_instances(
......@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
MaskingSpecialization::MaskUpperTringleFromTopLeft>{});
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
......@@ -241,7 +241,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
});
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
if(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
......
......@@ -31,7 +31,7 @@ template <typename ADataType,
typename B0Layout,
typename B1Layout,
typename CLayout,
bool MaskOutUpperTriangle>
bool MaskUpperTringleFromTopLeft>
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int init_method,
bool do_log,
......@@ -197,20 +197,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskOutUpperTriangle>;
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskUpperTringleFromTopLeft>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
......@@ -229,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskOutUpperTriangle && idx[1] < idx[2])
if(MaskUpperTringleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
......
......@@ -219,7 +219,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
if(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
......
......@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
using MaskUpperTringleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskOutUpperTriangle_t>
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t>
>;
// clang-format on
......
......@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
using MaskUpperTringleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskOutUpperTriangle_t>
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t>
>;
// clang-format on
......
......@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
MaskingSpecialization::MaskUpperTringleFromTopLeft>; // MaskUpperTringleFromTopLeft
bool IsSupported(int M, int N, int K, int O)
{
......@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
MaskingSpecialization::MaskUpperTringleFromTopLeft>; // MaskUpperTringleFromTopLeft
bool IsSupported(int M, int N, int K, int O)
{
......
......@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
using MaskUpperTringleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t>
>;
// clang-format on
......
......@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
using MaskUpperTringleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t>
>;
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment