Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b233584a
Commit
b233584a
authored
Apr 03, 2026
by
laibao
Committed by
zhangzbb
Apr 03, 2026
Browse files
[BUGFIX] 回退 ROCm FlashAttention unified KV layout 改动并修正 unified kernel 选择逻辑
parent
2888b4e5
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
143 additions
and
286 deletions
+143
-286
vllm/attention/layer.py
vllm/attention/layer.py
+1
-2
vllm/envs.py
vllm/envs.py
+0
-5
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-31
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+0
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+49
-184
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+89
-54
vllm/v1/attention/selector.py
vllm/v1/attention/selector.py
+0
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-3
No files found.
vllm/attention/layer.py
View file @
b233584a
...
...
@@ -245,7 +245,6 @@ class Attention(nn.Module, AttentionLayerBase):
use_mla
=
False
,
has_sink
=
self
.
has_sink
,
use_mm_prefix
=
self
.
use_mm_prefix
,
use_alibi_sqrt
=
bool
(
use_alibi_sqrt
),
attn_type
=
attn_type
,
)
else
:
...
...
vllm/envs.py
View file @
b233584a
...
...
@@ -1989,11 +1989,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
...
...
vllm/platforms/rocm.py
View file @
b233584a
...
...
@@ -262,7 +262,6 @@ class RocmPlatform(Platform):
from
vllm._aiter_ops
import
rocm_aiter_ops
block_size
=
attn_selector_config
.
block_size
head_size
=
attn_selector_config
.
head_size
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
attn_selector_config
.
use_sparse
:
...
...
@@ -305,36 +304,9 @@ class RocmPlatform(Platform):
f
"is not MLA type while requested for MLA backend."
)
is_non64_block_multiple_64
=
(
block_size
!=
64
and
block_size
%
64
==
0
)
use_unified_flash
=
(
is_non64_block_multiple_64
and
head_size
==
256
)
if
(
envs
.
VLLM_USE_FLASH_ATTN_PA
and
is_non64_block_multiple_64
and
head_size
!=
256
):
logger
.
info_once
(
"Skip unified varlen kernel on V1 engine: head size %d is "
"unsupported (requires 256)."
,
head_size
,
)
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
(
block_size
==
64
or
use_unified_flash
):
if
use_unified_flash
and
block_size
!=
64
:
logger
.
info_once
(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)"
,
block_size
,
)
else
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
block_size
==
64
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine. (only supports block size 64)"
)
return
AttentionBackendEnum
.
FLASH_ATTN
.
get_path
()
else
:
os
.
environ
[
'VLLM_USE_FLASH_ATTN_PA'
]
=
'0'
...
...
vllm/v1/attention/backend.py
View file @
b233584a
...
...
@@ -225,7 +225,6 @@ class AttentionBackend(ABC):
has_sink
:
bool
,
use_sparse
:
bool
,
use_mm_prefix
:
bool
,
use_alibi_sqrt
:
bool
,
device_capability
:
"DeviceCapability"
,
attn_type
:
str
,
)
->
list
[
str
]:
...
...
@@ -242,8 +241,6 @@ class AttentionBackend(ABC):
invalid_reasons
.
append
(
"partial multimodal token full attention not supported"
)
if
use_alibi_sqrt
and
not
cls
.
supports_alibi_sqrt
():
invalid_reasons
.
append
(
"use_alibi_sqrt not supported"
)
if
use_mla
!=
cls
.
is_mla
():
if
use_mla
:
invalid_reasons
.
append
(
"MLA not supported"
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
b233584a
...
...
@@ -33,13 +33,6 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
,
)
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
triton_reshape_and_cache_flash
,
)
try
:
from
flash_attn
import
varlen_fwd_unified
except
Exception
:
varlen_fwd_unified
=
None
else
:
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
...
...
@@ -120,38 +113,6 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
classmethod
def
supports_alibi_sqrt
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
True
@
staticmethod
def
_use_rocm_unified_kv_layout
(
block_size
:
int
|
None
=
None
,
key_cache
:
torch
.
Tensor
|
None
=
None
,
value_cache
:
torch
.
Tensor
|
None
=
None
,
)
->
bool
:
if
not
current_platform
.
is_rocm
():
return
False
if
block_size
is
None
:
if
key_cache
is
not
None
and
value_cache
is
not
None
:
if
key_cache
.
ndim
!=
4
or
value_cache
.
ndim
!=
4
:
return
False
if
key_cache
.
shape
!=
value_cache
.
shape
:
return
False
block_size
=
key_cache
.
shape
[
1
]
else
:
try
:
block_size
=
get_current_vllm_config
().
cache_config
.
block_size
except
Exception
:
return
False
return
block_size
is
not
None
and
block_size
!=
64
and
block_size
%
64
==
0
if
current_platform
.
is_rocm
():
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -163,9 +124,6 @@ class FlashAttentionBackend(AttentionBackend):
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
block_size
):
unified_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
unified_shape
,
unified_shape
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
...
...
@@ -178,17 +136,6 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
():
if
cache_layout
!=
"NHD"
:
raise
RuntimeError
(
"ROCm unified KV layout currently supports NHD only."
)
if
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
2
,
3
,
4
),
(
1
,
0
,
2
,
3
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
else
:
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
),
(
1
,
0
,
4
,
2
,
3
)
...
...
@@ -324,34 +271,8 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
max_num_splits
:
int
=
0
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
qq_bias
:
torch
.
Tensor
|
None
=
None
causal
:
bool
=
True
@
property
def
mm_prefix_range_tensor
(
self
)
->
torch
.
Tensor
|
None
:
if
self
.
mm_prefix_range
is
None
:
return
None
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
device
=
self
.
seq_lens
.
device
range_lists
=
[
self
.
mm_prefix_range
.
get
(
i
,
[(
0
,
0
)])
or
[(
0
,
0
)]
for
i
in
range
(
num_seqs
)
]
if
all
(
r
==
[(
0
,
0
)]
for
r
in
range_lists
):
return
None
range_tensors
=
[
torch
.
tensor
(
r
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
-
1
,
2
)
for
r
in
range_lists
]
return
torch
.
nested
.
nested_tensor
(
range_tensors
,
layout
=
torch
.
jagged
).
to_padded_tensor
(
0
)
def
_get_sliding_window_configs
(
vllm_config
:
VllmConfig
,
...
...
@@ -676,7 +597,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
str
|
None
=
None
,
sinks
:
torch
.
Tensor
|
None
=
None
,
use_alibi_sqrt
:
bool
=
False
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -702,7 +622,6 @@ class FlashAttentionImpl(AttentionImpl):
self
.
attn_type
=
attn_type
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
use_alibi_sqrt
=
use_alibi_sqrt
# Cache the batch invariant result for use in forward passes
self
.
batch_invariant_enabled
=
vllm_is_batch_invariant
()
...
...
@@ -729,14 +648,6 @@ class FlashAttentionImpl(AttentionImpl):
else
False
)
def
_get_unified_extras
(
self
,
attn_metadata
:
FlashAttentionMetadata
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
]:
mm_prefix_range_tensor
=
attn_metadata
.
mm_prefix_range_tensor
qq_bias
=
attn_metadata
.
qq_bias
return
mm_prefix_range_tensor
,
qq_bias
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -863,36 +774,6 @@ class FlashAttentionImpl(AttentionImpl):
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
use_unified_kv_layout
=
(
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
key_cache
=
key_cache
,
value_cache
=
value_cache
)
)
if
use_unified_kv_layout
:
mm_prefix_range_tensor
,
qq_bias
=
self
.
_get_unified_extras
(
attn_metadata
)
varlen_fwd_unified
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_seqlens_q
,
seqused_k
=
seqused_k
,
block_table
=
block_table
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
softcap
=
self
.
logits_soft_cap
,
window_size
=
tuple
(
self
.
sliding_window
),
alibi_slopes
=
self
.
alibi_slopes
,
use_alibi_sqrt
=
self
.
use_alibi_sqrt
,
qq_bias
=
qq_bias
,
s_aux
=
self
.
sinks
,
mm_prefix_range
=
mm_prefix_range_tensor
,
return_softmax_lse
=
False
,
out
=
output
[:
num_actual_tokens
],
)
else
:
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
...
...
@@ -1008,24 +889,8 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if
current_platform
.
is_rocm
():
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
key_cache
=
key_cache
,
value_cache
=
value_cache
,
):
triton_reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
b233584a
...
...
@@ -12,6 +12,10 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
try
:
from
flash_attn
import
varlen_fwd_unified
except
Exception
:
varlen_fwd_unified
=
None
logger
=
init_logger
(
__name__
)
float8_info
=
torch
.
finfo
(
current_platform
.
fp8_dtype
())
...
...
@@ -983,6 +987,14 @@ def unified_attention(
or
num_seqs
>
seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
use_fa_unified_2d
=
(
current_platform
.
is_rocm
()
and
varlen_fwd_unified
is
not
None
and
block_size
%
64
==
0
and
head_size
==
256
)
if
not
use_fa_unified_2d
:
# print("Running Triton kernel")
kernel_unified_attention_2d
[
(
total_num_q_blocks
,
...
...
@@ -1038,6 +1050,29 @@ def unified_attention(
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
)
else
:
# print("Running FA kernel")
varlen_fwd_unified
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
seqused_k
=
seqused_k
,
block_table
=
block_table
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
softcap
=
softcap
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
use_alibi_sqrt
=
use_alibi_sqrt
,
qq_bias
=
qq_bias
,
s_aux
=
sinks
,
mm_prefix_range
=
mm_prefix_range
,
return_softmax_lse
=
False
,
out
=
out
,
)
else
:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d
[
...
...
vllm/v1/attention/selector.py
View file @
b233584a
...
...
@@ -27,7 +27,6 @@ class AttentionSelectorConfig(NamedTuple):
has_sink
:
bool
=
False
use_sparse
:
bool
=
False
use_mm_prefix
:
bool
=
False
use_alibi_sqrt
:
bool
=
False
attn_type
:
str
=
AttentionType
.
DECODER
def
__repr__
(
self
):
...
...
@@ -40,7 +39,6 @@ class AttentionSelectorConfig(NamedTuple):
f
"has_sink=
{
self
.
has_sink
}
, "
f
"use_sparse=
{
self
.
use_sparse
}
, "
f
"use_mm_prefix=
{
self
.
use_mm_prefix
}
, "
f
"use_alibi_sqrt=
{
self
.
use_alibi_sqrt
}
, "
f
"attn_type=
{
self
.
attn_type
}
)"
)
...
...
@@ -54,7 +52,6 @@ def get_attn_backend(
has_sink
:
bool
=
False
,
use_sparse
:
bool
=
False
,
use_mm_prefix
:
bool
=
False
,
use_alibi_sqrt
:
bool
=
False
,
attn_type
:
str
|
None
=
None
,
)
->
type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
...
...
@@ -80,7 +77,6 @@ def get_attn_backend(
has_sink
=
has_sink
,
use_sparse
=
use_sparse
,
use_mm_prefix
=
use_mm_prefix
,
use_alibi_sqrt
=
use_alibi_sqrt
,
attn_type
=
attn_type
or
AttentionType
.
DECODER
,
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b233584a
...
...
@@ -5958,7 +5958,7 @@ class GPUModelRunner(
return
kv_caches
def
_update_hybrid_attention_mamba_layout
(
self
,
kv_caches
:
dict
[
str
,
Any
]
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
)
->
None
:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
...
...
@@ -5972,8 +5972,6 @@ class GPUModelRunner(
kv_cache_spec
=
group
.
kv_cache_spec
for
layer_name
in
group
.
layer_names
:
kv_cache
=
kv_caches
[
layer_name
]
if
not
isinstance
(
kv_cache
,
torch
.
Tensor
):
continue
if
isinstance
(
kv_cache_spec
,
AttentionSpec
)
and
kv_cache
.
shape
[
0
]
==
2
:
assert
kv_cache
.
shape
[
1
]
!=
2
,
(
"Fail to determine whether the layout is "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment