"test/vscode:/vscode.git/clone" did not exist on "08eb17692978f2af709deb59c98aae6db4c82b6b"
Commit 5af78ac2 authored by ltqin's avatar ltqin
Browse files

fix triagle name

parent 4a653a5d
...@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2 ...@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -401,13 +401,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1 ...@@ -401,13 +401,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -407,13 +407,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -407,13 +407,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
return MaskDisabledPredicate{}; return MaskDisabledPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
return MaskUpperTringleFromTopLeftPredicate{}; return MaskUpperTriangleFromTopLeftPredicate{};
} }
else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromBottomRight) else if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromBottomRight)
{ {
return MaskUpperTringleFromBottomRightPredicate{}; return MaskUpperTriangleFromBottomRightPredicate{};
} }
} }
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
......
...@@ -10,8 +10,8 @@ namespace device { ...@@ -10,8 +10,8 @@ namespace device {
enum struct MaskingSpecialization enum struct MaskingSpecialization
{ {
MaskDisabled, MaskDisabled,
MaskUpperTringleFromTopLeft, MaskUpperTriangleFromTopLeft,
MaskUpperTringleFromBottomRight MaskUpperTriangleFromBottomRight
}; };
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
...@@ -19,9 +19,9 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s ...@@ -19,9 +19,9 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
switch(s) switch(s)
{ {
case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskUpperTringleFromTopLeft: return "MaskUpperTringleFromTopLeft"; case MaskingSpecialization::MaskUpperTriangleFromTopLeft: return "MaskUpperTriangleFromTopLeft";
case MaskingSpecialization::MaskUpperTringleFromBottomRight: case MaskingSpecialization::MaskUpperTriangleFromBottomRight:
return "MaskUpperTringleFromBottomRight"; return "MaskUpperTriangleFromBottomRight";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
...@@ -40,7 +40,7 @@ struct MaskDisabledPredicate ...@@ -40,7 +40,7 @@ struct MaskDisabledPredicate
} }
}; };
struct MaskUpperTringleFromTopLeftPredicate struct MaskUpperTriangleFromTopLeftPredicate
{ {
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
...@@ -50,9 +50,9 @@ struct MaskUpperTringleFromTopLeftPredicate ...@@ -50,9 +50,9 @@ struct MaskUpperTringleFromTopLeftPredicate
return operator()(m + m_tile - 1, n); return operator()(m + m_tile - 1, n);
} }
}; };
struct MaskUpperTringleFromBottomRightPredicate struct MaskUpperTriangleFromBottomRightPredicate
{ {
MaskUpperTringleFromBottomRightPredicate() : offset_(0) {} MaskUpperTriangleFromBottomRightPredicate() : offset_(0) {}
__host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; } __host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; }
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const __host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{ {
...@@ -77,7 +77,7 @@ struct C0MatrixMask_impl ...@@ -77,7 +77,7 @@ struct C0MatrixMask_impl
C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) C0MatrixMask_impl(index_t MRaw, index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{})
{ {
if constexpr(std::is_same<MaskOutPredicate, if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTringleFromBottomRightPredicate>::value) MaskUpperTriangleFromBottomRightPredicate>::value)
{ {
if(NRaw > MRaw) if(NRaw > MRaw)
predicate_.SetOffset(NRaw - MRaw); predicate_.SetOffset(NRaw - MRaw);
......
...@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
...@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory< ...@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>) is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
...@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory< ...@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType::Size() == 1 && Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>) is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
......
...@@ -57,7 +57,7 @@ template <typename ALayout, ...@@ -57,7 +57,7 @@ template <typename ALayout,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
bool MaskUpperTringleFromTopLeft> bool MaskUpperTriangleFromTopLeft>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout, ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory< ...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskUpperTringleFromTopLeft>> MaskUpperTriangleFromTopLeft>>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout, using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout, B0Layout,
...@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory< ...@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskUpperTringleFromTopLeft>; MaskUpperTriangleFromTopLeft>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory< ...@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>) is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{ {
if constexpr(MaskUpperTringleFromTopLeft) if constexpr(MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs); op_ptrs);
......
...@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
...@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances); instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
...@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory< ...@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
...@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory< ...@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> && else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>) is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft) if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{ {
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs); op_ptrs);
......
...@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16 ...@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskUpperTringleFromTopLeft>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
...@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd, ScaleAdd,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_ ...@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskUpperTringleFromTopLeft>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
...@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16 ...@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskUpperTringleFromTopLeft>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
......
...@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale, Scale,
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskUpperTringleFromTopLeft>>>& MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances) instances)
{ {
add_device_operation_instances( add_device_operation_instances(
...@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f ...@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
1, 1,
1, 1,
1, 1,
MaskingSpecialization::MaskUpperTringleFromTopLeft>{}); MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
......
...@@ -241,7 +241,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -241,7 +241,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
}); });
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft && idx[1] < idx[2]) if(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
...@@ -31,7 +31,7 @@ template <typename ADataType, ...@@ -31,7 +31,7 @@ template <typename ADataType,
typename B0Layout, typename B0Layout,
typename B1Layout, typename B1Layout,
typename CLayout, typename CLayout,
bool MaskUpperTringleFromTopLeft> bool MaskUpperTriangleFromTopLeft>
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
...@@ -211,7 +211,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -211,7 +211,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
Acc0ElementOp, Acc0ElementOp,
B1ElementOp, B1ElementOp,
CElementOp, CElementOp,
MaskUpperTringleFromTopLeft>; MaskUpperTriangleFromTopLeft>;
// get device op instances // get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -230,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -230,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskUpperTringleFromTopLeft && idx[1] < idx[2]) if(MaskUpperTriangleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
...@@ -219,7 +219,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -219,7 +219,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskUpperTringleFromTopLeft && idx[1] < idx[2]) if(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
......
...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskUpperTringleFromTopLeft_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>; MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskUpperTringleFromTopLeft_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>; MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128 ...@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskUpperTringleFromTopLeft>; // MaskUpperTringleFromTopLeft MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
...@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128 ...@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskUpperTringleFromTopLeft>; // MaskUpperTringleFromTopLeft MaskingSpecialization::MaskUpperTriangleFromTopLeft>; // MaskUpperTriangleFromTopLeft
bool IsSupported(int M, int N, int K, int O) bool IsSupported(int M, int N, int K, int O)
{ {
......
...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>; ...@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using MaskDisabled_t = using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>; ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskUpperTringleFromTopLeft_t = using MaskUpperTriangleFromTopLeft_t =
ck::integral_constant<MaskingSpecialization, ck::integral_constant<MaskingSpecialization,
MaskingSpecialization::MaskUpperTringleFromTopLeft>; MaskingSpecialization::MaskUpperTriangleFromTopLeft>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>, std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskUpperTringleFromTopLeft_t> std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, MaskUpperTriangleFromTopLeft_t>
>; >;
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment