Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
321b6c8e
Commit
321b6c8e
authored
Jul 25, 2023
by
ltqin
Browse files
change enum to MaskUpperTringleFrom
parent
b4514459
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
88 additions
and
49 deletions
+88
-49
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
..._softmax_gemm/batched_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+9
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
321b6c8e
...
...
@@ -59,7 +59,7 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
View file @
321b6c8e
...
...
@@ -85,7 +85,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
View file @
321b6c8e
...
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
321b6c8e
...
...
@@ -94,7 +94,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
321b6c8e
...
...
@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
View file @
321b6c8e
...
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
321b6c8e
...
...
@@ -84,7 +84,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
View file @
321b6c8e
...
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
321b6c8e
...
...
@@ -93,7 +93,7 @@ static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
321b6c8e
...
...
@@ -319,9 +319,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -439,7 +443,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
D0sTransferSrcScalarPerVector
>
;
// Argument
...
...
@@ -503,7 +507,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c0de_element_op_
{
c0de_element_op
},
b1_element_op_
{
b1_element_op
},
c1de_element_op_
{
c1de_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
321b6c8e
...
...
@@ -364,7 +364,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0MatrixMask
=
conditional_t
<
MaskOutUpperTriangle
,
C0MatrixMask_impl
<
Mask
Out
UpperTri
a
nglePredicate
>
,
C0MatrixMask_impl
<
MaskUpperTringle
FromTopLeft
Predicate
>
,
C0MatrixMask_impl
<
MaskDisabledPredicate
>>
;
// GridwiseGemm
...
...
@@ -473,7 +473,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
c0_matrix_mask_
{
NRaw
},
c0_matrix_mask_
{
MRaw
,
NRaw
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
321b6c8e
...
...
@@ -564,9 +564,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -687,7 +691,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// Argument
...
...
@@ -772,7 +776,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
321b6c8e
...
...
@@ -570,9 +570,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -701,7 +705,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// Argument
...
...
@@ -785,7 +789,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
321b6c8e
...
...
@@ -559,9 +559,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -683,7 +687,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// Argument
...
...
@@ -768,7 +772,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
321b6c8e
...
...
@@ -565,9 +565,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
Mask
Out
UpperTri
a
nglePredicate
{};
return
MaskUpperTringle
FromTopLeft
Predicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v1.hpp
View file @
321b6c8e
...
...
@@ -386,9 +386,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -515,7 +519,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// Argument
...
...
@@ -588,7 +592,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
View file @
321b6c8e
...
...
@@ -394,9 +394,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -523,7 +527,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
// Argument
...
...
@@ -596,7 +600,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c0_matrix_mask_
{
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
c0_matrix_mask_
{
a_grid_desc_g_m_k_
.
GetLength
(
I1
),
b_grid_desc_g_n_k_
.
GetLength
(
I1
)},
raw_lengths_mz_nz_kz_gemm1nz_
{
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
+
NumDimK
-
1
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
321b6c8e
...
...
@@ -291,9 +291,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -399,7 +403,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
>
;
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -527,7 +531,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
a_grid_desc_g_m_k
,
b_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
);
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
321b6c8e
...
...
@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -622,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -818,7 +822,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
321b6c8e
...
...
@@ -500,9 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -630,7 +634,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -826,7 +830,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment