gaoqiong / flash-attention · Commits · d562aa63

Commit d562aa63 (Unverified), authored Jul 31, 2024 by Woosuk Kwon, committed by GitHub on Jul 31, 2024

Sync with FA v2.6.0 to support soft capping (#13)
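A clarifying note on the commit message: this diff only threads a new softcap argument through the Python interface; the capping itself is done by the FA v2.6.0 CUDA kernels. Below is a minimal pure-PyTorch sketch of what attention soft capping is generally understood to mean (an assumption based on the upstream FlashAttention v2.6.0 behaviour used by Gemma-2-style models, not on code shown in this commit):

import torch

def softcapped_attention_reference(q, k, v, softcap=0.0, softmax_scale=None):
    # Hedged reference only: q, k, v are (batch, seqlen, nheads, headdim),
    # matching the flash_attn_func layout. Scores are squashed with
    # softcap * tanh(scores / softcap) before the softmax; softcap=0.0 means
    # deactivated, matching the new keyword default introduced in this diff.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if softcap > 0.0:
        scores = softcap * torch.tanh(scores / softcap)
    probs = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v)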
parent 12375706

Showing 1 changed file with 79 additions and 13 deletions.

vllm_flash_attn/flash_attn_interface.py  (+79, -13)
@@ -46,7 +46,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):


 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -60,6 +60,7 @@ def _flash_attn_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -78,6 +79,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     return_softmax,
     block_table,
@@ -103,6 +105,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -125,13 +128,19 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -147,6 +156,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -172,13 +182,19 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
 ):
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -199,6 +215,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -217,6 +234,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -233,6 +251,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -242,6 +261,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -265,12 +285,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -284,6 +305,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -304,6 +326,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -315,6 +338,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -342,12 +366,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -360,6 +385,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -375,6 +401,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -384,6 +411,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -408,13 +436,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -431,6 +460,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -450,6 +480,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -464,6 +495,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -492,13 +524,14 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
@@ -512,6 +545,7 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -527,6 +561,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             out=out,
@@ -536,6 +571,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -558,6 +594,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -565,7 +602,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -583,6 +620,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -603,6 +641,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
@@ -617,6 +656,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -643,6 +683,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -650,7 +691,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(
@@ -659,6 +700,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # <=0.0 means deactivate
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -682,6 +724,7 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -704,6 +747,7 @@ def flash_attn_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -718,6 +762,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -757,6 +802,7 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -781,6 +827,7 @@ def flash_attn_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -796,6 +843,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -858,6 +906,7 @@ def flash_attn_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
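The flash_attn_func hunks above only add a pass-through keyword, so using soft capping from the public API is a one-line change. A hedged usage sketch (the module path follows the file shown in this diff, but tensor shapes and the value 30.0 are illustrative assumptions; requires a CUDA device and fp16 or bf16 inputs):

import torch
from vllm_flash_attn import flash_attn_interface  # import path assumed

# (batch, seqlen, nheads, headdim) fp16 tensors on GPU, per the docstring above.
q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)

# softcap defaults to 0.0 (deactivated); anything > 0 activates softcapping.
out = flash_attn_interface.flash_attn_func(q, k, v, causal=True, softcap=30.0)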
@@ -873,6 +922,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -899,6 +949,7 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -908,7 +959,7 @@ def flash_attn_varlen_qkvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
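The docstring correction above records that, for the varlen entry points, softmax_lse is returned per packed token rather than per padded sequence. A small hedged sketch of the shape bookkeeping (names and sizes are illustrative only):

import torch

# Two sequences of lengths 5 and 7 packed along one "total" dimension, which is
# how the varlen API receives its inputs; cu_seqlens marks the boundaries.
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32)
total_q = int(cu_seqlens[-1])   # 12 packed query tokens across the batch
nheads = 8

# Per the corrected docstring: (nheads, total_q_seqlen), not (batch_size, nheads, seqlen).
expected_softmax_lse_shape = (nheads, total_q)   # (8, 12)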
@@ -923,6 +974,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -941,6 +993,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -986,6 +1039,7 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -996,7 +1050,7 @@ def flash_attn_varlen_kvpacked_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1014,6 +1068,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1033,6 +1088,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -1077,6 +1133,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -1087,7 +1144,7 @@ def flash_attn_varlen_func(
             (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1106,6 +1163,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1128,9 +1186,11 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
     *,
     out=None,
 ):
@@ -1200,6 +1260,7 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
@@ -1211,9 +1272,13 @@ def flash_attn_with_kvcache(
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1244,7 +1309,8 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
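With the new return_softmax_lse flag, flash_attn_with_kvcache can hand back the attention logsumexp alongside the output, and softcap is forwarded to the kernel just as in the other entry points. A hedged usage sketch (shapes, the softcap value, and the import path are illustrative assumptions; requires a CUDA device):

import torch
from vllm_flash_attn import flash_attn_interface  # import path assumed

batch, nheads, headdim, cache_len = 2, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([17, 33], dtype=torch.int32, device="cuda")

# return_softmax_lse=False (the default) preserves the old behaviour of
# returning only `out`; True makes the call return (out, softmax_lse).
out, softmax_lse = flash_attn_interface.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    softcap=30.0,              # 0.0 means deactivated
    return_softmax_lse=True,
)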