gaoqiong / flash-attention · Commits · 8f48a546

Unverified commit 8f48a546, authored Jul 24, 2024 by youkaichao, committed by GitHub on Jul 24, 2024

use global function rather than lambda (#7)

Parent commit: 537f75eb

Showing 1 changed file with 2 additions and 5 deletions:

vllm_flash_attn/flash_attn_interface.py (+2, -5)
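
The diff below deletes five per-function copies of a maybe_contiguous lambda and replaces them with a single module-level definition. The commit message does not state a motivation, so the note that follows is only illustrative: a module-level def is created once at import time and can be looked up (and pickled) by its qualified name, whereas a lambda bound inside each wrapper is rebuilt on every call and has no importable name. A minimal sketch of that difference, with per_call_variant as a hypothetical stand-in for the old pattern:

import pickle


def maybe_contiguous(x):
    # The module-level helper this commit introduces: defined once and
    # resolvable by qualified name.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def per_call_variant(x):
    # Hypothetical stand-in for the old pattern: the lambda object is rebuilt
    # on every call and has no importable name.
    _maybe_contiguous = lambda t: t.contiguous() if t.stride(-1) != 1 else t
    return _maybe_contiguous(x)


pickle.dumps(maybe_contiguous)   # succeeds: serialized by reference to its name
try:
    pickle.dumps(lambda t: t)    # fails: a lambda cannot be pickled by name
except (pickle.PicklingError, AttributeError) as err:
    print("lambda is not picklable:", err)
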
vllm_flash_attn/flash_attn_interface.py (view file @ 8f48a546)

@@ -11,6 +11,8 @@ import vllm_flash_attn_2_cuda as flash_attn_cuda
 # isort: on
 
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
@@ -46,7 +48,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
@@ -83,7 +84,6 @@ def _flash_attn_varlen_forward(
     *,
     out=None
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
@@ -129,7 +129,6 @@ def _flash_attn_backward(
     deterministic,
     rng_state=None,
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
@@ -177,7 +176,6 @@ def _flash_attn_varlen_backward(
     deterministic,
     rng_state=None,
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
@@ -1219,7 +1217,6 @@ def flash_attn_with_kvcache(
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
-    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
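
For reference, a quick way to check the behavior of the hoisted helper (a sketch that assumes PyTorch is installed; the tensor shape is arbitrary and no GPU is needed):

import torch


def maybe_contiguous(x):
    # Copied from the definition this commit adds at module level.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


q = torch.randn(2, 8, 4).transpose(1, 2)  # transposing makes stride(-1) != 1
print(q.stride(-1))                       # 4: last dimension is not contiguous
print(maybe_contiguous(q).stride(-1))     # 1: copied to a contiguous layout
print(maybe_contiguous(None) is None)     # True: None passes through unchanged

Each call site in the diff is unchanged; the wrappers now share this single global definition, which also keeps the None check previously used only in flash_attn_with_kvcache.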