Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8afad0f6
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "6eef0755c923c8dbf7860e31fb8c2e1b8859bc6e"
Commit
8afad0f6
authored
Dec 30, 2022
by
ltqin
Browse files
open mask
parent
ac3c1563
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
7 deletions
+14
-7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+14
-7
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
8afad0f6
...
...
@@ -24,6 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 1
#include <iostream>
#include <numeric>
...
...
@@ -69,8 +70,13 @@ static constexpr ck::index_t NumDimK = 1;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskOutUpperTriangle
;
#else
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskOutUpperTriangle
;
#endif
static
constexpr
auto
TensorSpecQ
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
...
...
@@ -203,12 +209,13 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// masking
#if 0
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#if USING_MASK
auto
N
=
s_g_m_n
.
GetLengths
()[
1
];
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
N
);
s_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
mask
.
IsMaskedElement
(
idx
[
1
],
idx
[
2
]))
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
#endif
// P = Softmax(S)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment