Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe743b79
Unverified
Commit
fe743b79
authored
Feb 09, 2025
by
youkaichao
Committed by
GitHub
Feb 09, 2025
Browse files
[bugfix] fix early import of flash attention (#12959)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
913df14d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
19 deletions
+20
-19
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+7
-6
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+3
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+6
-8
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-3
No files found.
vllm/attention/backends/flash_attn.py
View file @
fe743b79
...
@@ -14,8 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -14,8 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
(
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
VLLM_FLASH_ATTN_VERSION
,
CommonAttentionState
,
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping
_start_idx
,
get_flash_attn_version
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
is_block_tables_empty
)
...
@@ -640,6 +640,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -640,6 +640,7 @@ class FlashAttentionImpl(AttentionImpl):
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -759,7 +760,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -759,7 +760,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
out
=
prefill_output
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
...
@@ -782,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -782,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
out
=
prefill_output
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
@@ -811,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -811,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
out
=
decode_output
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
else
:
else
:
# Use flash_attn_with_kvcache for normal decoding.
# Use flash_attn_with_kvcache for normal decoding.
...
@@ -832,7 +833,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -832,7 +833,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
return
output
return
output
...
...
vllm/attention/backends/mla/utils.py
View file @
fe743b79
...
@@ -12,7 +12,7 @@ from vllm import envs
...
@@ -12,7 +12,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionMetadata
,
AttentionMetadata
,
MLAAttentionImpl
,
T
)
MLAAttentionImpl
,
T
)
from
vllm.attention.backends.utils
import
VLLM_FLASH_ATTN_VERSION
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -181,6 +181,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -181,6 +181,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
if
envs
.
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
...
@@ -515,7 +516,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -515,7 +516,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k
=
max_prefill_seq_len
,
max_seqlen_k
=
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
attn_output
=
attn_output
\
attn_output
=
attn_output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
...
...
vllm/attention/backends/utils.py
View file @
fe743b79
...
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
...
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
num_decode_query_tokens
)
num_decode_query_tokens
)
try
:
def
get_flash_attn_version
():
try
:
from
vllm.vllm_flash_attn.flash_attn_interface
import
(
from
vllm.vllm_flash_attn.flash_attn_interface
import
(
fa_version_unsupported_reason
,
is_fa_version_supported
)
fa_version_unsupported_reason
,
is_fa_version_supported
)
def
flash_attn_version
():
# if hopper default to FA3, otherwise stick to FA2 for now
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
# use FA3 as default for both
...
@@ -610,7 +610,5 @@ try:
...
@@ -610,7 +610,5 @@ try:
assert
is_fa_version_supported
(
fa_version
)
assert
is_fa_version_supported
(
fa_version
)
return
fa_version
return
fa_version
except
(
ImportError
,
AssertionError
):
VLLM_FLASH_ATTN_VERSION
=
flash_attn_version
()
return
None
except
(
ImportError
,
AssertionError
):
VLLM_FLASH_ATTN_VERSION
=
None
vllm/v1/attention/backends/flash_attn.py
View file @
fe743b79
...
@@ -10,7 +10,7 @@ import triton.language as tl
...
@@ -10,7 +10,7 @@ import triton.language as tl
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
VLLM_FLASH_ATTN_VERSION
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
"FlashAttentionImpl"
)
"FlashAttentionImpl"
)
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
return
output
return
output
...
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
)
return
output
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment