flash-attention
Commit ef0ed106
Authored Jan 31, 2024 by Tri Dao
Parent: dc72d960

Add window_size option to MHA and GPT
Showing 2 changed files with 31 additions and 4 deletions:

    flash_attn/models/gpt.py    +2  -0
    flash_attn/modules/mha.py   +29 -4
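A note on the new option (not part of the commit itself): window_size is a (left, right) pair, and the default (-1, -1) used throughout this diff leaves attention unrestricted. A minimal sketch of the masking pattern such a window implies for self-attention, assuming query position i may attend to key positions j with i - left <= j <= i + right:

import torch

def sliding_window_mask(seqlen, window_size):
    # True where attention is allowed; a negative bound disables that side of the window.
    left, right = window_size
    i = torch.arange(seqlen)[:, None]
    j = torch.arange(seqlen)[None, :]
    mask = torch.ones(seqlen, seqlen, dtype=torch.bool)
    if left >= 0:
        mask &= j >= i - left
    if right >= 0:
        mask &= j <= i + right
    return mask

# window_size=(2, 0): each token sees itself and the two previous tokens.
print(sliding_window_mask(5, (2, 0)).int())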
flash_attn/models/gpt.py
@@ -78,6 +78,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
     rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
     use_alibi = getattr(config, "use_alibi", False)
+    window_size = getattr(config, "window_size", (-1, -1))
     use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     if not fused_bias_fc:
@@ -110,6 +111,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
         rotary_emb_scale_base=rotary_emb_scale_base,
         rotary_emb_interleaved=rotary_emb_interleaved,
         use_alibi=use_alibi,
+        window_size=window_size,
         use_flash_attn=use_flash_attn,
         **serial_kwargs,
         **parallel_kwargs,
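A hedged usage sketch for the gpt.py change above: the new attribute is read with getattr(config, "window_size", (-1, -1)) and forwarded to the MHA mixer, so it can be set directly on the config object. GPT2Config and GPTLMHeadModel are assumptions about the surrounding API, not part of this commit:

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)
config.use_flash_attn = True    # the sliding-window path requires flash_attn (see mha.py below)
config.window_size = (256, 0)   # assumed example: look back 256 tokens, never ahead
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)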
flash_attn/modules/mha.py
@@ -61,7 +61,15 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +77,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -104,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -113,6 +123,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
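A minimal sketch of FlashSelfAttention with the new argument, assuming a packed qkv tensor of shape (batch, seqlen, 3, nheads, headdim) in fp16 on a CUDA device, as the flash_attn kernels require:

import torch
from flash_attn.modules.mha import FlashSelfAttention

attn = FlashSelfAttention(causal=True, window_size=(128, 0))
qkv = torch.randn(2, 512, 3, 16, 64, device="cuda", dtype=torch.float16)
out = attn(qkv)  # (2, 512, 16, 64); each query attends to at most the 128 previous keys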
@@ -128,7 +139,15 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        alibi_slopes=None,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -136,6 +155,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.window_size = window_size
         self.deterministic = deterministic

     def forward(
@@ -184,6 +204,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -197,6 +218,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
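The same option on FlashCrossAttention, sketched under the same assumptions (fp16, CUDA): q is (batch, seqlen_q, nheads, headdim) and kv is packed as (batch, seqlen_k, 2, nheads, headdim):

import torch
from flash_attn.modules.mha import FlashCrossAttention

cross_attn = FlashCrossAttention(causal=True, window_size=(128, 0))
q = torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.float16)
kv = torch.randn(2, 512, 2, 16, 64, device="cuda", dtype=torch.float16)
out = cross_attn(q, kv)  # (2, 512, 16, 64)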
@@ -372,6 +394,7 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
+        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -401,6 +424,8 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -431,12 +456,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
             if use_flash_attn
             else CrossAttention
         )
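Putting the mha.py changes together, a hedged sketch of constructing MHA directly with a sliding window; note the new guard above requires use_flash_attn=True whenever window_size is not (-1, -1). The concrete sizes are illustrative assumptions:

import torch
from flash_attn.modules.mha import MHA

mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    use_flash_attn=True,     # required by the new assertion when a window is set
    window_size=(128, 0),    # assumed example: causal local attention over 128 tokens
    device="cuda",
    dtype=torch.float16,
)
x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
out = mha(x)  # (2, 512, 1024)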