gaoqiong / flash-attention

Commit a2974e85
Authored Aug 26, 2023 by Tri Dao
Parent: 9b713872

Change causal for CrossAttention in mha.py to align to bottom right

Showing 1 changed file with 15 additions and 22 deletions:

flash_attn/modules/mha.py (+15, -22)
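In short: CrossAttention's pure-PyTorch path used to build its causal mask with torch.triu, which anchors the triangle at the top-left corner, so when seqlen_q != seqlen_k query i could only see keys 0..i. The new mask is anchored at the bottom-right corner: query i sees keys up to i + seqlen_k - seqlen_q, and the last query always attends to every key. This presumably matches the bottom-right causal convention adopted by the FlashAttention 2.1 kernels, keeping this fallback consistent with them.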
@@ -271,13 +271,18 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(
-                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
-            )
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
+            )
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = (
+                seqlen_k
+                if key_padding_mask is None
+                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            )
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
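A minimal sketch (not part of the commit) of what the new mask computes, assuming key_padding_mask is None so sk == seqlen_k:

import torch
from einops import rearrange

# Bottom-right-aligned causal mask for 2 queries against 5 keys.
seqlen_q, seqlen_k = 2, 5
row_idx = rearrange(torch.arange(seqlen_q, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, dtype=torch.long)
causal_mask = col_idx > row_idx + seqlen_k - seqlen_q  # True = masked out

print(causal_mask.long())
# tensor([[0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])

The last query row is never masked. The old torch.triu(..., 1) mask was anchored top-left instead and would have hidden keys 1-4 from query 0 and keys 2-4 from query 1.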
@@ -627,10 +632,7 @@ class MHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
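This deletion (repeated in the three hunks below) falls out of the new alignment: when decoding, seqlen_q == 1 while the KV cache supplies all seqlen_k keys, and a bottom-right-aligned causal mask for a single query masks nothing, so causal=True is already a no-op and the causal=False override becomes redundant. A quick check, assuming the mask formula from the first hunk:

import torch

# One new token attending to a KV cache of 9 entries.
seqlen_q, seqlen_k = 1, 9
row_idx = torch.arange(seqlen_q, dtype=torch.long).unsqueeze(-1)  # [[0]]
col_idx = torch.arange(seqlen_k, dtype=torch.long)
causal_mask = col_idx > row_idx + seqlen_k - seqlen_q  # col_idx > 8: never true

print(causal_mask.any())  # tensor(False) -> the query attends to the whole cache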
@@ -677,10 +679,7 @@ class MHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
@@ -869,10 +868,7 @@ class ParallelMHA(nn.Module):
                 else:
                     q = qkv[:, :, 0]
                     kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -903,10 +899,7 @@ class ParallelMHA(nn.Module):
                     )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
-                    # If we're processing the prompt, causal=None (use self.causal).
-                    # If we're decoding, then causal=False.
-                    causal = None if inference_params.sequence_len_offset == 0 else False
-                    context = self.inner_cross_attn(q, kv, causal=causal)
+                    context = self.inner_cross_attn(q, kv)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
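For completeness, a hedged usage sketch of the changed module. The constructor arguments and tensor shapes below are read off this file (q: (batch, seqlen_q, nheads, headdim), kv: (batch, seqlen_k, 2, nheads, headdim)) and may differ in other versions:

import torch
from flash_attn.modules.mha import CrossAttention

attn = CrossAttention(causal=True, attention_dropout=0.0)
q = torch.randn(2, 3, 8, 64)      # 3 queries per batch element
kv = torch.randn(2, 7, 2, 8, 64)  # packed k/v: 7 keys and values
out = attn(q, kv)                 # causal mask anchored bottom-right
print(out.shape)                  # torch.Size([2, 3, 8, 64])

With causal=True and seqlen_q < seqlen_k, only the upper-right corner of the score matrix is masked, so each of the 3 queries still sees at least the first 5 keys.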