Commit a3f11fe9 authored by ltqin's avatar ltqin
Browse files

fix name

parent 92238f48
......@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromBottonRight;
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTringleFromBottomRight;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
......
......@@ -569,9 +569,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
return MaskOutUpperTrianglePredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottonRight)
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight)
{
return MaskUpperTringleFromBottonRightPredicate{};
return MaskUpperTringleFromBottomRightPredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
......@@ -11,7 +11,7 @@ enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle,
MaskUpperTringleFromBottonRight
MaskUpperTringleFromBottomRight
};
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
......@@ -20,8 +20,8 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
case MaskingSpecialization::MaskUpperTringleFromBottonRight:
return "MaskUpperTringleFromBottonRight";
case MaskingSpecialization::MaskUpperTringleFromBottomRight:
return "MaskUpperTringleFromBottomRight";
default: return "Unrecognized specialization!";
}
}
......@@ -50,9 +50,9 @@ struct MaskOutUpperTrianglePredicate
return operator()(m + m_tile - 1, n);
}
};
struct MaskUpperTringleFromBottonRightPredicate
struct MaskUpperTringleFromBottomRightPredicate
{
MaskUpperTringleFromBottonRightPredicate() : offset_(0) {}
MaskUpperTringleFromBottomRightPredicate() : offset_(0) {}
__host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; }
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{
......@@ -77,7 +77,7 @@ struct C0MatrixMask_impl
C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTringleFromBottonRightPredicate>::value)
MaskUpperTringleFromBottomRightPredicate>::value)
{
if(NRaw > MRaw)
predicate_.SetOffset(NRaw - MRaw);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment