gaoqiong / flash-attention / Commits / fdf6d72b

Unverified commit fdf6d72b, authored Nov 28, 2024 by Woosuk Kwon, committed by GitHub on Nov 28, 2024.

Clean up API & Bypass torch.autograd.Function (#30)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: 7d3409be
Pipeline #2013 failed with stages in 0 seconds.
Showing 2 changed files with 27 additions and 920 deletions:

  vllm_flash_attn/__init__.py               +0   -4
  vllm_flash_attn/flash_attn_interface.py   +27  -916
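The "Bypass torch.autograd.Function" part of the change reflects that vLLM only ever runs the forward pass, so the autograd wrappers below add bookkeeping without benefit. A minimal, self-contained sketch of the general pattern (toy computation and names, not the real kernel):

```python
import torch

# Toy stand-in for the attention kernel; the real code calls a custom C++ op.
def _toy_kernel(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)

# Before: an autograd.Function wrapper, with ctx bookkeeping and gradient plumbing.
class ToyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _toy_kernel(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0)

# After: when only inference is needed, a plain function calling the kernel suffices,
# which is the spirit of what this commit does for the flash-attention entry points.
def toy_forward_only(x: torch.Tensor) -> torch.Tensor:
    return _toy_kernel(x)
```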
vllm_flash_attn/__init__.py

@@ -3,10 +3,6 @@ __version__ = "2.6.2"

# Use relative import to support build-from-source installation in vLLM
from .flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
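For orientation, a hedged usage sketch of flash_attn_func, one of the exported entry points, as it behaves after this commit (shapes are illustrative and a CUDA build of vllm_flash_attn is assumed):

```python
import torch
from vllm_flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) -- illustrative sizes.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# Forward-only causal attention.
out = flash_attn_func(q, k, v, causal=True)

# As of this commit, the logsumexp can also be requested via return_softmax_lse.
out, softmax_lse = flash_attn_func(q, k, v, causal=True, return_softmax_lse=True)
```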
vllm_flash_attn/flash_attn_interface.py

@@ -65,9 +65,7 @@ def _flash_attn_forward(
        return_softmax,
        None,
    )
    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
    # because we only use the forward pass in the vLLM.
    return out, q, k, v, out, softmax_lse, None, None
    return out, softmax_lse


def _flash_attn_varlen_forward(
@@ -112,732 +110,7 @@ def _flash_attn_varlen_forward(
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    # NOTE(woosuk): out_padded, S_dmask, and rng_state are None
    # because we only use the forward pass in the vLLM.
    return out, q, k, v, None, softmax_lse, None, None
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = torch.ops.vllm_flash_attn_c.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
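The maybe_contiguous helper used above is not part of this hunk; judging from its use, it presumably just forces a contiguous innermost dimension when needed. An assumed sketch, not the library's definition:

```python
def maybe_contiguous(x):
    # Assumed behaviour: leave None alone, and only copy when the last stride
    # is not 1 (i.e., the innermost dimension is not contiguous).
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
```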
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    deterministic,
    rng_state=None,
):
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = torch.ops.vllm_flash_attn_c.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
        None,
        rng_state,
    )
    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        *,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        *,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
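The softmax_lse value described in the Return section above is just the row-wise logsumexp of the scaled QK^T scores, so it can be checked against a plain PyTorch reference. A small, hedged sketch of that reference (unbatched, single head, no masking; names illustrative):

```python
import math
import torch

def reference_softmax_lse(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q, k: (seqlen, headdim). Scale defaults to 1 / sqrt(headdim), as in the docstring.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale   # (seqlen_q, seqlen_k)
    return torch.logsumexp(scores, dim=-1)       # (seqlen_q,)
```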
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
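The bottom-right alignment of the causal mask described in the docstring above can be reproduced with a few lines of plain PyTorch; this illustrative helper (not part of the library) builds the same 1/0 mask as the seqlen_q = 2, seqlen_k = 5 example:

```python
import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q.
    i = torch.arange(seqlen_q).unsqueeze(1)
    j = torch.arange(seqlen_k).unsqueeze(0)
    return (j <= i + seqlen_k - seqlen_q).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```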
    return out, softmax_lse


def flash_attn_func(
@@ -853,6 +126,7 @@ def flash_attn_func(
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -896,189 +170,26 @@ def flash_attn_func(
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = _flash_attn_forward(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
    )
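The cu_seqlens argument used by the varlen entry points is the zero-prefixed cumulative sum of the per-sequence lengths, so that sequence b occupies rows cu_seqlens[b]:cu_seqlens[b+1] of the packed tensor. A hedged sketch of how such inputs are typically assembled (sizes illustrative):

```python
import torch

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)          # per-sequence lengths
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                  # tensor([0, 5, 8, 15])
total = int(cu_seqlens[-1])                                    # total tokens in the batch

nheads, headdim = 8, 64
qkv = torch.randn(total, 3, nheads, headdim)                   # packed (total, 3, nheads, headdim)
max_seqlen = int(seqlens.max())                                # 7
```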
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        out,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        return_softmax=return_attn_probs and dropout_p > 0,
        out=out,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_varlen_func(
@@ -1099,6 +210,7 @@ def flash_attn_varlen_func(
    return_attn_probs=False,
    block_table=None,
    *,
    return_softmax_lse=False,
    out=None,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -1149,14 +261,13 @@ def flash_attn_varlen_func(
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = _flash_attn_varlen_forward(
        q,
        k,
        v,
@@ -1166,15 +277,15 @@ def flash_attn_varlen_func(
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        return_softmax=return_attn_probs and dropout_p > 0,
        block_table=block_table,
        out=out,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def flash_attn_with_kvcache(
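Tying the pieces together, a hedged end-to-end sketch of the reworked varlen entry point as it reads after this commit (CUDA build assumed; shapes and values illustrative):

```python
import torch
from vllm_flash_attn import flash_attn_varlen_func

seqlens = [5, 3, 7]                                            # per-sequence lengths
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")
total, nheads, headdim = 15, 8, 64

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)

out, softmax_lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
    return_softmax_lse=True,   # keyword added in this commit; returns (out, softmax_lse)
)
```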