Unverified Commit 2018bd28 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #811 from ROCmSoftwarePlatform/attn-train-develop-qloop-mask

Add another mask (upper triangle from bottom right) to flash attention
parents 120760d6 a822937a
...@@ -219,7 +219,8 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -219,7 +219,8 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2]) if(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft &&
idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>; ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskOutUpperTriangle_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>; ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskOutUpperTriangle_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128 ...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
...@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128 ...@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
......
...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>; ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>; ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128 ...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
...@@ -315,7 +315,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128 ...@@ -315,7 +315,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment