Unverified Commit 2018bd28 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #811 from ROCmSoftwarePlatform/attn-train-develop-qloop-mask

Add another mask(upper tringle from bottom right) to flash attetion
parents 120760d6 a822937a
...@@ -577,9 +577,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -577,9 +577,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -709,7 +713,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -709,7 +713,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
// Argument // Argument
...@@ -793,7 +797,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -793,7 +797,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
...@@ -395,9 +395,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -395,9 +395,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -524,7 +528,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -524,7 +528,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
// Argument // Argument
...@@ -597,7 +601,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -597,7 +601,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
...@@ -405,9 +405,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -405,9 +405,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -534,7 +538,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -534,7 +538,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
// Argument // Argument
...@@ -607,7 +611,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -607,7 +611,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, c0_matrix_mask_{a_grid_desc_g_m_k_.GetLength(I1), b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
......
...@@ -291,9 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -291,9 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -399,7 +403,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -399,7 +403,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpec != MaskingSpecialization::MaskDisabled>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -527,7 +531,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -527,7 +531,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n); a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -622,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -622,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -818,7 +822,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -818,7 +822,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -630,7 +634,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -630,7 +634,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -826,7 +830,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -826,7 +830,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -628,7 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -628,7 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -828,7 +832,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -828,7 +832,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -636,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -636,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -836,7 +840,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -836,7 +840,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -405,9 +405,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -405,9 +405,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -535,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -535,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -701,7 +705,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -701,7 +705,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -411,9 +411,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -411,9 +411,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskOutUpperTrianglePredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{
return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
...@@ -541,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -541,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
Transform::matrix_padder.PadN, Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec != MaskingSpecialization::MaskDisabled,
Deterministic>; Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>; using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -712,7 +716,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -712,7 +716,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask =
C0MatrixMask(a_grid_desc_g_m_k.GetLength(I1), b_grid_desc_g_n_k.GetLength(I1));
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -10,7 +10,8 @@ namespace device { ...@@ -10,7 +10,8 @@ namespace device {
enum struct MaskingSpecialization enum struct MaskingSpecialization
{ {
MaskDisabled, MaskDisabled,
MaskOutUpperTriangle MaskUpperTriangleFromTopLeft,
MaskUpperTriangleFromBottomRight
}; };
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
...@@ -18,7 +19,9 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s ...@@ -18,7 +19,9 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
switch(s) switch(s)
{ {
case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle"; case MaskingSpecialization::MaskUpperTriangleFromTopLeft: return "MaskUpperTriangleFromTopLeft";
case MaskingSpecialization::MaskUpperTriangleFromBottomRight:
return "MaskUpperTriangleFromBottomRight";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
...@@ -37,7 +40,7 @@ struct MaskDisabledPredicate ...@@ -37,7 +40,7 @@ struct MaskDisabledPredicate
} }
}; };
struct MaskOutUpperTrianglePredicate struct MaskUpperTriangleFromTopLeftPredicate
{ {
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
...@@ -48,12 +51,50 @@ struct MaskOutUpperTrianglePredicate ...@@ -48,12 +51,50 @@ struct MaskOutUpperTrianglePredicate
} }
}; };
// eg: m = 3, n = 5 => offset = 2
// so matrix(n > m + offset) = 0
// 1 2 3 4 5
// 1 * * * 0 0
// 2 * * * * 0
// 3 * * * * *
struct MaskUpperTriangleFromBottomRightPredicate
{
MaskUpperTriangleFromBottomRightPredicate() : diagonal_offset_(0) {}
__host__ __device__ void SetDiagonalOffset(const index_t diagonal_offset)
{
diagonal_offset_ = diagonal_offset;
}
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{
return n > (m + diagonal_offset_);
}
__host__ __device__ constexpr bool IsTileSkippable(index_t m_tile_orig,
index_t n_tile_orig,
index_t m_tile_size,
index_t /*n_tile_size*/) const
{
return operator()(m_tile_orig + m_tile_size - 1, n_tile_orig);
}
private:
index_t diagonal_offset_;
};
// to track the points which need to be set to -inf on C0 // to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out. // Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate> template <typename MaskOutPredicate>
struct C0MatrixMask_impl struct C0MatrixMask_impl
{ {
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTriangleFromBottomRightPredicate>::value)
{
if(NRaw > MRaw)
predicate_.SetDiagonalOffset(NRaw - MRaw);
}
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
......
...@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
...@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory< ...@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>) is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
...@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory< ...@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>) is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
......
...@@ -57,7 +57,7 @@ template <typename ALayout, ...@@ -57,7 +57,7 @@ template <typename ALayout,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
bool MaskOutUpperTriangle> bool MaskUpperTriangleFromTopLeft>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory< ...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskOutUpperTriangle>> MaskUpperTriangleFromTopLeft>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory< ...@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskOutUpperTriangle>; MaskUpperTriangleFromTopLeft>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory< ...@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>) is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{ {
if constexpr(MaskOutUpperTriangle) if constexpr(MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs); op_ptrs);
......
...@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
...@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory< ...@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
...@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory< ...@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> && else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>) is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
......
...@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
...@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
...@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
...@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
...@@ -241,7 +241,8 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -241,7 +241,8 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
}); });
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2]) if(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft &&
idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
...@@ -31,7 +31,7 @@ template <typename ADataType, ...@@ -31,7 +31,7 @@ template <typename ADataType,
typename B0Layout, typename B0Layout,
typename B1Layout, typename B1Layout,
typename CLayout, typename CLayout,
bool MaskOutUpperTriangle> bool MaskUpperTriangleFromTopLeft>
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
...@@ -197,20 +197,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -197,20 +197,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto b1_element_op = B1ElementOp{}; auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout, using DeviceOp =
B0Layout, tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B1Layout, B0Layout,
CLayout, B1Layout,
ADataType, CLayout,
B0DataType, ADataType,
B1DataType, B0DataType,
CDataType, B1DataType,
AElementOp, CDataType,
B0ElementOp, AElementOp,
Acc0ElementOp, B0ElementOp,
B1ElementOp, Acc0ElementOp,
CElementOp, B1ElementOp,
MaskOutUpperTriangle>; CElementOp,
MaskUpperTriangleFromTopLeft>;
// get device op instances // get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -229,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -229,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskOutUpperTriangle && idx[1] < idx[2]) if(MaskUpperTriangleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment