Commit 5af78ac2 authored by ltqin's avatar ltqin
Browse files

fix triagle name

parent 4a653a5d
...@@ -59,7 +59,7 @@ using CElementOp = PassThrough; ...@@ -59,7 +59,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
......
...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromBottomRight; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromBottomRight;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -58,7 +58,7 @@ using CElementOp = PassThrough; ...@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
......
...@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromTopLeft; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -319,13 +319,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -319,13 +319,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -364,7 +364,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -364,7 +364,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0MatrixMask = conditional_t<MaskOutUpperTriangle, using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
C0MatrixMask_impl<MaskUpperTringleFromTopLeftPredicate>, C0MatrixMask_impl<MaskUpperTriangleFromTopLeftPredicate>,
C0MatrixMask_impl<MaskDisabledPredicate>>; C0MatrixMask_impl<MaskDisabledPredicate>>;
// GridwiseGemm // GridwiseGemm
......
...@@ -564,13 +564,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -564,13 +564,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -570,13 +570,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -570,13 +570,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -559,13 +559,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -559,13 +559,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -565,13 +565,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -565,13 +565,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -386,13 +386,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -386,13 +386,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -394,13 +394,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -394,13 +394,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -291,13 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -291,13 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1 ...@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment