Commit a3f11fe9 authored by ltqin's avatar ltqin
Browse files

fix name

parent 92238f48
...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8; ...@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromBottonRight; ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromBottomRight;
#else #else
static constexpr auto MaskingSpec = static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
...@@ -569,9 +569,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -569,9 +569,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
return MaskOutUpperTrianglePredicate{}; return MaskOutUpperTrianglePredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottonRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{ {
return MaskUpperTringleFromBottonRightPredicate{}; return MaskUpperTringleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -11,7 +11,7 @@ enum struct MaskingSpecialization ...@@ -11,7 +11,7 @@ enum struct MaskingSpecialization
{ {
MaskDisabled, MaskDisabled,
MaskOutUpperTriangle, MaskOutUpperTriangle,
MaskUpperTringleFromBottonRight MaskUpperTringleFromBottomRight
}; };
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
...@@ -20,8 +20,8 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s ...@@ -20,8 +20,8 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
{ {
case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle"; case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
case MaskingSpecialization::MaskUpperTringleFromBottonRight: case MaskingSpecialization::MaskUpperTringleFromBottomRight:
return "MaskUpperTringleFromBottonRight"; return "MaskUpperTringleFromBottomRight";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
...@@ -50,9 +50,9 @@ struct MaskOutUpperTrianglePredicate ...@@ -50,9 +50,9 @@ struct MaskOutUpperTrianglePredicate
return operator()(m + m_tile - 1, n); return operator()(m + m_tile - 1, n);
} }
}; };
struct MaskUpperTringleFromBottonRightPredicate struct MaskUpperTringleFromBottomRightPredicate
{ {
MaskUpperTringleFromBottonRightPredicate() : offset_(0) {} MaskUpperTringleFromBottomRightPredicate() : offset_(0) {}
__host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; } __host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; }
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const __host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{ {
...@@ -77,7 +77,7 @@ struct C0MatrixMask_impl ...@@ -77,7 +77,7 @@ struct C0MatrixMask_impl
C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{})
{ {
if constexpr(std::is_same<MaskOutPredicate, if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTringleFromBottonRightPredicate>::value) MaskUpperTringleFromBottomRightPredicate>::value)
{ {
if(NRaw > MRaw) if(NRaw > MRaw)
predicate_.SetOffset(NRaw - MRaw); predicate_.SetOffset(NRaw - MRaw);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment