gaoqiong / flash-attention

Commit a190df01, authored Feb 04, 2024 by Tri Dao (parent: 2423cca3)

Add window_size option to ParallelMHA

Showing 1 changed file with 5 additions and 2 deletions:
flash_attn/modules/mha.py (+5, -2)
@@ -747,6 +747,7 @@ class ParallelMHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         use_flash_attn=False,
         checkpointing=False,
         sequence_parallel=True,
@@ -793,6 +794,8 @@ class ParallelMHA(nn.Module):
             )
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -816,12 +819,12 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
        )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
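With this change, window_size defaults to (-1, -1), which keeps full attention; passing a (left, right) tuple of non-negative ints enables FlashAttention's sliding-window (local) attention, and the new check makes explicit that this path requires use_flash_attn=True. Below is a minimal usage sketch, not part of the commit: it assumes a torch.distributed process group launched via torchrun, a world size that divides num_heads, and illustrative sizes.

import torch
import torch.distributed as dist

from flash_attn.modules.mha import ParallelMHA

dist.init_process_group(backend="nccl")  # assumes launch via torchrun
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Causal sliding-window attention: each token attends to at most the 256
# previous tokens (left=256) and none to its right (right=0).
# window_size=(-1, -1), the default, would keep full attention.
mha = ParallelMHA(
    embed_dim=1024,
    num_heads=16,
    process_group=dist.group.WORLD,
    causal=True,
    window_size=(256, 0),
    use_flash_attn=True,      # required: the sliding-window path asserts this
    sequence_parallel=False,  # keep the sketch simple: full sequence on every rank
    device="cuda",
    dtype=torch.float16,
)

x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
out = mha(x)  # (batch, seqlen, embed_dim)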