Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
981f8370
Commit
981f8370
authored
Aug 22, 2025
by
zhuwenwen
Browse files
update v1 fa layout and set v1 attention use fa
parent
c0f0b209
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
165 additions
and
70 deletions
+165
-70
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+84
-40
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+78
-27
No files found.
vllm/attention/utils/fa_utils.py
View file @
981f8370
...
...
@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata
)
elif
current_platform
.
is_rocm
():
from
vllm
import
_custom_ops
as
ops
reshape_and_cache_
flash
=
ops
.
reshape_and_cache_
flash
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
reshape_and_cache_
cuda
=
ops
.
reshape_and_cache_
cuda
from
flash_attn
import
vllm_flash_attn_varlen_func
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
...
...
vllm/platforms/rocm.py
View file @
981f8370
...
...
@@ -289,7 +289,7 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1
if
envs
.
is_set
(
"VLLM_USE_FLASH_ATTN_PA"
)
and
envs
.
VLLM_USE_FLASH_ATTN_PA
and
block_size
==
64
:
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
block_size
==
64
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine. (only supports block size 64)"
)
return
FLASH_ATTN_V1
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
981f8370
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
numpy
as
np
import
torch
...
...
@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata
,
reshape_and_cache_flash
)
else
:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
reshape_and_cache_flash
)
from
vllm.attention.utils.fa_utils
import
(
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
...
@@ -84,6 +83,7 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
if
not
current_platform
.
is_rocm
():
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
...
...
@@ -107,6 +107,35 @@ class FlashAttentionBackend(AttentionBackend):
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
else
:
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@staticmethod
def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the ``(key, value)`` stride orders for the KV cache.

    The layout name is obtained from ``get_kv_cache_layout()``:

    * ``"NHD"`` -> identity permutation ``(0, 1, 2, 3)`` for both caches.
    * ``"HND"`` -> ``(0, 2, 1, 3)`` (dimensions 1 and 2 swapped) for both
      caches.

    Returns:
        A pair ``(key_stride_order, value_stride_order)``; the two
        entries are always equal.

    Raises:
        ValueError: If the layout is neither ``"NHD"`` nor ``"HND"``.
    """
    # `stride_order` indicates the permutation that gets
    # us from `get_kv_cache_shape` to the actual memory layout we want.
    orders_by_layout = {
        "NHD": (0, 1, 2, 3),
        "HND": (0, 2, 1, 3),
    }
    layout = get_kv_cache_layout()
    order = orders_by_layout.get(layout)
    if order is None:
        raise ValueError(f"Unknown cache layout format {layout}.")
    # Key and value caches use the same permutation.
    return order, order
@
staticmethod
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
...
...
@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
,
layer
)
# For decoder and cross-attention, use KV cache as before
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
...
...
@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
key
,
value
,
...
...
@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
...
...
@@ -582,7 +626,7 @@ class FlashAttentionImpl(AttentionImpl):
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
unsqueeze
(
1
).
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
...
...
@@ -607,7 +651,7 @@ class FlashAttentionImpl(AttentionImpl):
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
is_prefix_cache
=
Fals
e
,
is_prefix_cache
=
Tru
e
,
)
return
output
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
981f8370
...
...
@@ -3095,6 +3095,57 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec
.
page_size_bytes
)
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
has_attn
=
True
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
not
kv_cache_spec
.
use_mla
:
key_cache_shape
,
value_cache_shape
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
try
:
key_stride_order
,
value_stride_order
=
attn_backend
.
get_kv_cache_stride_order
()
assert
len
(
key_stride_order
)
==
len
(
key_cache_shape
)
assert
len
(
value_stride_order
)
==
len
(
value_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
key_stride_order
=
tuple
(
range
(
len
(
key_cache_shape
)))
value_stride_order
=
tuple
(
range
(
len
(
value_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantics remain consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
key_cache_shape
=
tuple
(
key_cache_shape
[
i
]
for
i
in
key_stride_order
)
value_cache_shape
=
tuple
(
value_cache_shape
[
i
]
for
i
in
value_stride_order
)
# Maintain original KV shape view.
inv_key_order
=
[
key_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
key_stride_order
))
]
inv_value_order
=
[
value_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
value_stride_order
))
]
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
)
total_elements
=
raw_tensor
.
numel
()
key_elements
=
(
key_cache_shape
[
0
]
*
key_cache_shape
[
1
]
*
key_cache_shape
[
2
]
*
key_cache_shape
[
3
])
value_elements
=
(
value_cache_shape
[
0
]
*
value_cache_shape
[
1
]
*
value_cache_shape
[
2
]
*
value_cache_shape
[
3
])
assert
total_elements
==
key_elements
+
value_elements
key_cache
=
raw_tensor
[:
key_elements
].
view
(
key_cache_shape
).
permute
(
*
inv_key_order
)
value_cache
=
raw_tensor
[
key_elements
:].
view
(
value_cache_shape
).
permute
(
*
inv_value_order
)
kv_caches
[
layer_name
]
=
(
key_cache
,
value_cache
)
else
:
kv_cache_shape
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment