Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c786e757
Unverified
Commit
c786e757
authored
Feb 06, 2025
by
Lucas Wilkinson
Committed by
GitHub
Feb 06, 2025
Browse files
[Attention] Use FA3 for MLA on Hopper (#12807)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
cefd56ee
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
59 deletions
+51
-59
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+11
-33
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+2
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+34
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-26
No files found.
vllm/attention/backends/flash_attn.py
View file @
c786e757
...
...
@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadataBuilder
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping
_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_
seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_
encoder
_attn_metadata_set
,
is_
block_tables_empty
)
from
vllm.envs
import
VLLM_FLASH_ATTN_VERSION
PAD_SLOT_ID
,
VLLM_FLASH_ATTN_VERSION
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
get_
num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_
cross
_attn_metadata_set
,
is_
all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
is_fa_version_supported
)
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl):
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if
current_platform
.
get_device_capability
()[
0
]
>=
9
:
self
.
fa_version
=
3
if
is_fa_version_supported
(
3
)
else
2
else
:
self
.
fa_version
=
2
if
VLLM_FLASH_ATTN_VERSION
is
not
None
:
assert
VLLM_FLASH_ATTN_VERSION
in
[
2
,
3
]
self
.
fa_version
=
VLLM_FLASH_ATTN_VERSION
if
not
is_fa_version_supported
(
self
.
fa_version
):
logger
.
error
(
"Cannot use FA version %d is not supported due to %s"
,
self
.
fa_version
,
fa_version_unsupported_reason
(
self
.
fa_version
))
assert
is_fa_version_supported
(
self
.
fa_version
)
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
else
:
# prefix-enabled attention
...
...
@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
...
...
@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
return
output
...
...
vllm/attention/backends/mla/utils.py
View file @
c786e757
...
...
@@ -12,6 +12,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
,
T
)
from
vllm.attention.backends.utils
import
VLLM_FLASH_ATTN_VERSION
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
max_seqlen_k
=
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
attn_output
=
attn_output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
...
...
vllm/attention/backends/utils.py
View file @
c786e757
...
...
@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.logger
import
logging
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
logger
=
logging
.
getLogger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
...
@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
try:
    from vllm.vllm_flash_attn.flash_attn_interface import (
        fa_version_unsupported_reason, is_fa_version_supported)

    def flash_attn_version():
        """Select the FlashAttention major version (2 or 3) to use.

        Selection order:

        1. Default: 3 when the first component of
           ``current_platform.get_device_capability()`` is >= 9 and
           ``is_fa_version_supported(3)`` is true; otherwise 2.
        2. Override: when ``envs.VLLM_FLASH_ATTN_VERSION`` is not ``None``
           it replaces the default and must be 2 or 3.

        Returns:
            The selected version, ``2`` or ``3``.

        Raises:
            AssertionError: if the override is not 2 or 3, or if the
                selected version is reported as unsupported. The checks
                are written as explicit raises rather than ``assert``
                statements so they still run under ``python -O``; the
                exception type is the one ``assert`` would have raised.
        """
        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        #  use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            fa_version = 2

        # An explicit setting always wins over the hardware-based default.
        override = envs.VLLM_FLASH_ATTN_VERSION
        if override is not None:
            if override not in (2, 3):
                raise AssertionError(
                    "VLLM_FLASH_ATTN_VERSION must be 2 or 3, "
                    f"got {override!r}")
            fa_version = override

        # Query support once; log the reason before failing so it shows up
        # in the logs even if the exception is reported elsewhere.
        if not is_fa_version_supported(fa_version):
            reason = fa_version_unsupported_reason(fa_version)
            logger.error(
                "Cannot use FA version %d: it is not supported due to %s",
                fa_version, reason)
            raise AssertionError(
                f"FA version {fa_version} is not supported: {reason}")

        return fa_version

    # Resolved once at import time; other modules import this constant.
    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
except ImportError:
    # vllm_flash_attn is not importable in this environment, so there is
    # no FlashAttention version to report.
    VLLM_FLASH_ATTN_VERSION = None
vllm/v1/attention/backends/flash_attn.py
View file @
c786e757
...
...
@@ -10,13 +10,10 @@ import triton.language as tl
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.
env
s
import
VLLM_FLASH_ATTN_VERSION
from
vllm.
attention.backends.util
s
import
VLLM_FLASH_ATTN_VERSION
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
is_fa_version_supported
)
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
logger
=
init_logger
(
__name__
)
...
...
@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl"
)
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if
current_platform
.
get_device_capability
()[
0
]
>=
9
:
self
.
fa_version
=
3
if
is_fa_version_supported
(
3
)
else
2
else
:
self
.
fa_version
=
2
if
VLLM_FLASH_ATTN_VERSION
is
not
None
:
assert
VLLM_FLASH_ATTN_VERSION
in
[
2
,
3
]
self
.
fa_version
=
VLLM_FLASH_ATTN_VERSION
if
not
is_fa_version_supported
(
self
.
fa_version
):
logger
.
error
(
"Cannot use FA version %d is not supported due to %s"
,
self
.
fa_version
,
fa_version_unsupported_reason
(
self
.
fa_version
))
assert
is_fa_version_supported
(
self
.
fa_version
)
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
return
output
...
...
@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
fa_version
,
fa_version
=
VLLM_FLASH_ATTN_VERSION
,
)
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment