Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d03e4bf6
Commit
d03e4bf6
authored
Mar 19, 2026
by
laibao
Browse files
feat(attn): ROCm块大小为64倍数(且不等于64)时走FA varlen_fwd_unified
parent
1ea9a3f0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
210 additions
and
146 deletions
+210
-146
vllm/envs.py
vllm/envs.py
+0
-5
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+17
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+136
-49
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+54
-88
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-1
No files found.
vllm/envs.py
View file @
d03e4bf6
...
...
@@ -316,7 +316,6 @@ if TYPE_CHECKING:
VLLM_USE_CUDA_GRAPH_SIZES
:
bool
=
False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
bool
=
False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK
:
bool
=
False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D
:
bool
=
False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING
:
bool
=
False
...
...
@@ -1978,10 +1977,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_ENABLE_RAY_ASYNC_SCHEDULING"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
#If set to 1/True, enable the flash attention unified path.
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/platforms/rocm.py
View file @
d03e4bf6
...
...
@@ -304,9 +304,23 @@ class RocmPlatform(Platform):
f
"is not MLA type while requested for MLA backend."
)
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
block_size
==
64
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine. (only supports block size 64)"
)
use_unified_flash
=
(
block_size
is
not
None
and
block_size
!=
64
and
block_size
%
64
==
0
)
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
(
block_size
==
64
or
use_unified_flash
):
if
use_unified_flash
and
block_size
!=
64
:
logger
.
info_once
(
"Using Flash Attention backend with unified varlen kernel on "
"V1 engine. (block size %d, requires block size divisible by 64)"
,
block_size
,
)
else
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine. "
"(only supports block size 64)"
)
return
AttentionBackendEnum
.
FLASH_ATTN
.
get_path
()
else
:
os
.
environ
[
'VLLM_USE_FLASH_ATTN_PA'
]
=
'0'
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
d03e4bf6
...
...
@@ -33,6 +33,13 @@ if is_flash_attn_varlen_func_available():
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
,
)
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
triton_reshape_and_cache_flash
,
)
try
:
from
flash_attn
import
varlen_fwd_unified
except
Exception
:
varlen_fwd_unified
=
None
else
:
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_sinks
,
...
...
@@ -112,6 +119,30 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
def
_use_rocm_unified_kv_layout
(
block_size
:
int
|
None
=
None
,
key_cache
:
torch
.
Tensor
|
None
=
None
,
value_cache
:
torch
.
Tensor
|
None
=
None
,
)
->
bool
:
if
not
current_platform
.
is_rocm
():
return
False
if
block_size
is
None
:
if
key_cache
is
not
None
and
value_cache
is
not
None
:
if
key_cache
.
ndim
!=
4
or
value_cache
.
ndim
!=
4
:
return
False
if
key_cache
.
shape
!=
value_cache
.
shape
:
return
False
block_size
=
key_cache
.
shape
[
1
]
else
:
try
:
block_size
=
get_current_vllm_config
().
cache_config
.
block_size
except
Exception
:
return
False
return
block_size
is
not
None
and
block_size
!=
64
and
block_size
%
64
==
0
if
current_platform
.
is_rocm
():
@
staticmethod
...
...
@@ -124,6 +155,9 @@ class FlashAttentionBackend(AttentionBackend):
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
block_size
):
unified_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
unified_shape
,
unified_shape
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
...
...
@@ -136,20 +170,31 @@ class FlashAttentionBackend(AttentionBackend):
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
),
(
1
,
0
,
4
,
2
,
3
)
elif
cache_layout
==
"NHD"
:
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
():
if
cache_layout
!=
"NHD"
:
raise
RuntimeError
(
"ROCm unified KV layout currently supports NHD only."
)
if
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
2
,
3
,
4
),
(
1
,
0
,
2
,
3
,
4
)
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return
(
1
,
2
,
0
,
3
,
4
),
(
1
,
2
,
0
,
4
,
3
)
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
if
cache_layout
==
"NHD"
and
include_num_layers_dimension
:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return
(
1
,
0
,
3
,
2
,
5
),
(
1
,
0
,
4
,
2
,
3
)
elif
cache_layout
==
"NHD"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
2
,
3
)
elif
cache_layout
==
"HND"
and
include_num_layers_dimension
:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return
(
1
,
2
,
0
,
3
,
4
),
(
1
,
2
,
0
,
4
,
3
)
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
else
:
@
staticmethod
...
...
@@ -774,30 +819,57 @@ class FlashAttentionImpl(AttentionImpl):
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
# num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
use_unified_kv_layout
=
(
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
key_cache
=
key_cache
,
value_cache
=
value_cache
)
)
if
use_unified_kv_layout
:
varlen_fwd_unified
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_seqlens_q
,
seqused_k
=
seqused_k
,
block_table
=
block_table
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
softcap
=
self
.
logits_soft_cap
,
window_size
=
tuple
(
self
.
sliding_window
),
alibi_slopes
=
self
.
alibi_slopes
,
use_alibi_sqrt
=
False
,
qq_bias
=
None
,
s_aux
=
self
.
sinks
,
mm_prefix_range
=
None
,
return_softmax_lse
=
False
,
out
=
output
[:
num_actual_tokens
],
)
else
:
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
sliding_window_size
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
# num_splits=attn_metadata.max_num_splits,
s_aux
=
self
.
sinks
,
is_prefix_cache
=
True
,
)
else
:
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
...
...
@@ -889,21 +961,11 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if
current_platform
.
is_rocm
():
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
else
:
from
vllm.v1.attention.backends.fa_utils
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
if
FlashAttentionBackend
.
_use_rocm_unified_kv_layout
(
key_cache
=
key_cache
,
value_cache
=
value_cache
,
):
triton_reshape_and_cache_flash
(
key
,
value
,
key_cache
,
...
...
@@ -913,6 +975,31 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
else
:
from
vllm.v1.attention.backends.fa_utils
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
reshape_and_cache_flash
(
key
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
d03e4bf6
...
...
@@ -12,11 +12,6 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm
import
envs
try
:
from
flash_attn
import
varlen_fwd_unified
except
Exception
:
varlen_fwd_unified
=
None
logger
=
init_logger
(
__name__
)
float8_info
=
torch
.
finfo
(
current_platform
.
fp8_dtype
())
...
...
@@ -988,90 +983,61 @@ def unified_attention(
or
num_seqs
>
seq_threshold_3D
):
# print(f"[2D Triton] k shape: {k.shape}, v shape: {v.shape}")
if
not
envs
.
VLLM_V1_USE_FA_UNIFIED_ATTN_2D
:
# print("Running Triton kernel")
kernel_unified_attention_2d
[
(
total_num_q_blocks
,
num_kv_heads
,
)
](
output_ptr
=
out
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
sink_ptr
=
sinks
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
qq_bias_ptr
=
qq_bias
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
out_scale
=
1
/
output_scale
if
output_scale
is
not
None
else
1.0
,
softcap
=
softcap
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
qq_bias_stride_0
=
qq_bias
.
stride
(
0
)
if
use_qq_bias
else
0
,
BLOCK_SIZE
=
block_size
,
TILE_SIZE
=
TILE_SIZE_PREFILL
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_ALIBI_SQRT
=
use_alibi_sqrt
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_MM_PREFIX
=
use_mm_prefix
,
MAX_MM_RANGES
=
max_mm_ranges
,
mm_prefix_range_ptr
=
mm_prefix_range
,
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_2
=
k
.
stride
(
2
),
stride_k_cache_3
=
k
.
stride
(
3
),
stride_v_cache_0
=
v
.
stride
(
0
),
stride_v_cache_1
=
v
.
stride
(
1
),
stride_v_cache_2
=
v
.
stride
(
2
),
stride_v_cache_3
=
v
.
stride
(
3
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
)
else
:
if
varlen_fwd_unified
is
None
:
raise
RuntimeError
(
"flash_attn.varlen_fwd_unified is not available in this flash-attn version"
)
# print("Running FA kernel")
varlen_fwd_unified
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
seqused_k
=
seqused_k
,
block_table
=
block_table
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
softcap
=
softcap
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
use_alibi_sqrt
=
use_alibi_sqrt
,
qq_bias
=
qq_bias
,
s_aux
=
sinks
,
mm_prefix_range
=
mm_prefix_range
,
return_softmax_lse
=
False
,
out
=
out
,
kernel_unified_attention_2d
[
(
total_num_q_blocks
,
num_kv_heads
,
)
](
output_ptr
=
out
,
query_ptr
=
q
,
key_cache_ptr
=
k
,
value_cache_ptr
=
v
,
sink_ptr
=
sinks
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seqused_k
,
alibi_slopes_ptr
=
alibi_slopes
,
qq_bias_ptr
=
qq_bias
,
scale
=
softmax_scale
,
k_scale
=
k_descale
,
v_scale
=
v_descale
,
out_scale
=
1
/
output_scale
if
output_scale
is
not
None
else
1.0
,
softcap
=
softcap
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
q
.
stride
(
0
),
query_stride_1
=
q
.
stride
(
1
),
output_stride_0
=
out
.
stride
(
0
),
output_stride_1
=
out
.
stride
(
1
),
qq_bias_stride_0
=
qq_bias
.
stride
(
0
)
if
use_qq_bias
else
0
,
BLOCK_SIZE
=
block_size
,
TILE_SIZE
=
TILE_SIZE_PREFILL
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_ALIBI_SQRT
=
use_alibi_sqrt
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_MM_PREFIX
=
use_mm_prefix
,
MAX_MM_RANGES
=
max_mm_ranges
,
mm_prefix_range_ptr
=
mm_prefix_range
,
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_2
=
k
.
stride
(
2
),
stride_k_cache_3
=
k
.
stride
(
3
),
stride_v_cache_0
=
v
.
stride
(
0
),
stride_v_cache_1
=
v
.
stride
(
1
),
stride_v_cache_2
=
v
.
stride
(
2
),
stride_v_cache_3
=
v
.
stride
(
3
),
query_start_len_ptr
=
cu_seqlens_q
,
BLOCK_Q
=
BLOCK_Q
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
)
else
:
# print(f"[3D Triton] k shape: {k.shape}, v shape: {v.shape}")
kernel_unified_attention_3d
[
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d03e4bf6
...
...
@@ -5951,7 +5951,7 @@ class GPUModelRunner(
return
kv_caches
def
_update_hybrid_attention_mamba_layout
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
self
,
kv_caches
:
dict
[
str
,
Any
]
)
->
None
:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
...
...
@@ -5965,6 +5965,8 @@ class GPUModelRunner(
kv_cache_spec
=
group
.
kv_cache_spec
for
layer_name
in
group
.
layer_names
:
kv_cache
=
kv_caches
[
layer_name
]
if
not
isinstance
(
kv_cache
,
torch
.
Tensor
):
continue
if
isinstance
(
kv_cache_spec
,
AttentionSpec
)
and
kv_cache
.
shape
[
0
]
==
2
:
assert
kv_cache
.
shape
[
1
]
!=
2
,
(
"Fail to determine whether the layout is "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment