Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5af78ac2
Commit
5af78ac2
authored
Jul 26, 2023
by
ltqin
Browse files
fix triagle name
parent
4a653a5d
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
47 additions
and
47 deletions
+47
-47
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
..._softmax_gemm/batched_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+4
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
5af78ac2
...
@@ -59,7 +59,7 @@ using CElementOp = PassThrough;
...
@@ -59,7 +59,7 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
View file @
5af78ac2
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
5af78ac2
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
View file @
5af78ac2
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
5af78ac2
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
5af78ac2
...
@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
...
@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
View file @
5af78ac2
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
5af78ac2
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
View file @
5af78ac2
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
5af78ac2
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
;
#else
#else
static
constexpr
auto
MaskingSpec
=
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
5af78ac2
...
@@ -319,13 +319,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -319,13 +319,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
5af78ac2
...
@@ -364,7 +364,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -364,7 +364,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0MatrixMask
=
conditional_t
<
MaskOutUpperTriangle
,
using
C0MatrixMask
=
conditional_t
<
MaskOutUpperTriangle
,
C0MatrixMask_impl
<
MaskUpperTringleFromTopLeftPredicate
>
,
C0MatrixMask_impl
<
MaskUpperTri
a
ngleFromTopLeftPredicate
>
,
C0MatrixMask_impl
<
MaskDisabledPredicate
>>
;
C0MatrixMask_impl
<
MaskDisabledPredicate
>>
;
// GridwiseGemm
// GridwiseGemm
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
5af78ac2
...
@@ -564,13 +564,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -564,13 +564,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
5af78ac2
...
@@ -570,13 +570,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
...
@@ -570,13 +570,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
5af78ac2
...
@@ -559,13 +559,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -559,13 +559,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
5af78ac2
...
@@ -565,13 +565,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -565,13 +565,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
View file @
5af78ac2
...
@@ -386,13 +386,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
...
@@ -386,13 +386,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
5af78ac2
...
@@ -394,13 +394,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
...
@@ -394,13 +394,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
5af78ac2
...
@@ -291,13 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -291,13 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
5af78ac2
...
@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
{
return
MaskDisabledPredicate
{};
return
MaskDisabledPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment