Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82f60bef
Commit
82f60bef
authored
May 28, 2025
by
zhuwenwen
Browse files
set VLLM_FLASH_ATTN_BACKEND to use FlashAttention Backend for attention computation on rocm
parent
07b41ddf
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
380 additions
and
195 deletions
+380
-195
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+154
-79
vllm/envs.py
vllm/envs.py
+14
-9
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+94
-13
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+118
-94
No files found.
vllm/attention/backends/flash_attn.py
View file @
82f60bef
...
...
@@ -27,7 +27,12 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
else
:
from
flash_attn
import
(
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
...
...
@@ -807,6 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
(
num_kv_tokens
,
num_kv_heads
,
head_size
))
descale_shape
=
(
q_seq_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
...
...
@@ -826,6 +832,21 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
else
:
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
...
...
@@ -835,6 +856,7 @@ class FlashAttentionImpl(AttentionImpl):
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
...
...
@@ -855,6 +877,27 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
vllm_flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
...
...
@@ -870,6 +913,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
decode_meta
.
query_start_loc
is
not
None
descale_shape
=
(
decode_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
...
...
@@ -890,6 +934,22 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
seqused_k
=
decode_meta
.
seq_lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
(
...
...
@@ -898,6 +958,7 @@ class FlashAttentionImpl(AttentionImpl):
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
descale_shape
=
(
seq_lens_arg
.
shape
[
0
],
key_cache
.
shape
[
-
2
])
if
not
current_platform
.
is_rocm
():
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
...
...
@@ -915,6 +976,20 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
decode_output
=
decode_output
.
unsqueeze
(
1
)
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
return
output
...
...
vllm/envs.py
View file @
82f60bef
...
...
@@ -124,7 +124,7 @@ if TYPE_CHECKING:
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_FLASH_ATTN_BACKEND
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_ZERO_OVERHEAD
:
bool
=
False
...
...
@@ -799,14 +799,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
,
"-1"
)),
# Enable two batch overlap.
"VLLM_ENABLE_TBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_TBO"
,
"0"
))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
# 'has_comtext' is a variable in common.py, which is calculated
# by metadata by default. However, it may introduce synchronization
# and affect performance, so it is directly assigned as False.
...
...
@@ -814,6 +806,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_HAS_CONTEXT_DEFAULT"
,
"0"
))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_BACKEND"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Enable two batch overlap.
"VLLM_ENABLE_TBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_TBO"
,
"0"
))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
}
# end-env-vars-definition
...
...
vllm/platforms/rocm.py
View file @
82f60bef
...
...
@@ -205,6 +205,87 @@ class RocmPlatform(Platform):
# f" The selected backend, {selected_backend.name},"
# f"is not MLA type while requested for MLA backend.")
if
envs
.
VLLM_FLASH_ATTN_BACKEND
:
if
use_v1
:
if
selected_backend
==
_Backend
.
FLASHINFER
:
raise
ValueError
(
"FlashInfer backend on V1 engine is not supported"
)
if
selected_backend
==
_Backend
.
TRITON_ATTN_VLLM_V1
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend"
)
if
cls
.
has_device_capability
(
80
):
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend"
)
if
selected_backend
==
_Backend
.
FLASHINFER
:
raise
ValueError
(
"FlashInfer backend is not supported"
)
elif
selected_backend
==
_Backend
.
XFORMERS
:
raise
ValueError
(
"XFormers backend is not supported"
)
elif
selected_backend
==
_Backend
.
FLASH_ATTN
:
pass
elif
selected_backend
:
raise
ValueError
(
f
"Invalid attention backend for
{
cls
.
device_name
}
, "
f
"with use_v1:
{
use_v1
}
use_mla:
{
use_mla
}
"
)
target_backend
=
_Backend
.
FLASH_ATTN
if
not
cls
.
has_device_capability
(
80
):
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs."
)
raise
ValueError
(
"XFormers backend is not supported"
)
elif
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
raise
ValueError
(
"XFormers backend is not supported"
)
elif
block_size
%
16
!=
0
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16."
)
raise
ValueError
(
"XFormers backend is not supported"
)
# FlashAttn is valid for the model, checking if the package is
# installed.
if
target_backend
==
_Backend
.
FLASH_ATTN
:
try
:
import
flash_attn
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
,
flash_attn_supports_fp8
)
supported_sizes
=
\
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
supported_sizes
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for head size %d."
,
head_size
)
raise
ValueError
(
"XFormers backend is not supported"
)
fp8_kv_cache
=
(
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
))
if
(
fp8_kv_cache
and
not
flash_attn_supports_fp8
()):
logger
.
info
(
"Cannot use FlashAttention backend for FP8 KV cache."
)
logger
.
warning
(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER"
)
raise
ValueError
(
"XFormers backend is not supported"
)
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention-2 backend because the "
"flash_attn package is not found. "
"Make sure that flash_attn was built and installed "
"(on by default)."
)
raise
ValueError
(
"XFormers backend is not supported"
)
if
target_backend
==
_Backend
.
XFORMERS
:
raise
ValueError
(
"XFormers backend is not supported"
)
logger
.
info
(
"Using Flash Attention backend."
)
return
"vllm.attention.backends.flash_attn.FlashAttentionBackend"
else
:
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
envs
.
VLLM_USE_V1
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
82f60bef
...
...
@@ -24,11 +24,11 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_
cuda
():
if
not
current_platform
.
is_
rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
get_scheduler_metadata
)
else
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
logger
=
init_logger
(
__name__
)
...
...
@@ -605,6 +605,7 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
...
...
@@ -626,11 +627,30 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
# scheduler_metadata=scheduler_metadata,
)
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
if
not
current_platform
.
is_rocm
():
cascade_attention
(
output
[:
num_actual_tokens
],
query
[:
num_actual_tokens
],
...
...
@@ -656,6 +676,8 @@ class FlashAttentionImpl(AttentionImpl):
v_descale
=
layer
.
_v_scale
,
)
return
output
else
:
raise
ValueError
(
"cascade attention is not supported on rocm"
)
def
use_cascade_attention
(
...
...
@@ -763,6 +785,7 @@ def cascade_attention(
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process shared prefix.
if
not
current_platform
.
is_rocm
():
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
...
...
@@ -790,6 +813,7 @@ def cascade_attention(
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process suffix per query.
if
not
current_platform
.
is_rocm
():
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment