Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
864c718a
Commit
864c718a
authored
Aug 19, 2025
by
zhuwenwen
Browse files
update v1 fa layout
parent
693d5ed4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
168 additions
and
71 deletions
+168
-71
vllm/attention/layer.py
vllm/attention/layer.py
+2
-2
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+2
-2
vllm/config.py
vllm/config.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+83
-39
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+80
-27
No files found.
vllm/attention/layer.py
View file @
864c718a
...
@@ -75,7 +75,7 @@ class Attention(nn.Module):
...
@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
else
:
kv_cache_dtype
=
"auto"
kv_cache_dtype
=
"auto"
block_size
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
block_size
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
is_attention_free
=
False
is_attention_free
=
False
calculate_kv_scales
=
False
calculate_kv_scales
=
False
if
num_kv_heads
is
None
:
if
num_kv_heads
is
None
:
...
@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module):
attn_backend
=
get_attn_backend
(
head_size
,
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
dtype
,
kv_cache_dtype
=
None
,
kv_cache_dtype
=
None
,
block_size
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
,
block_size
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
,
is_attention_free
=
False
)
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
...
vllm/attention/utils/fa_utils.py
View file @
864c718a
...
@@ -15,8 +15,8 @@ if current_platform.is_cuda():
...
@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata
)
get_scheduler_metadata
)
elif
current_platform
.
is_rocm
():
elif
current_platform
.
is_rocm
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
reshape_and_cache_
flash
=
ops
.
reshape_and_cache_
flash
reshape_and_cache_
cuda
=
ops
.
reshape_and_cache_
cuda
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
from
flash_attn
import
vllm_flash_attn_varlen_func
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
...
...
vllm/config.py
View file @
864c718a
...
@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
...
@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class
CacheConfig
:
class
CacheConfig
:
"""Configuration for the KV cache."""
"""Configuration for the KV cache."""
block_size
:
BlockSize
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
# type: ignore
block_size
:
BlockSize
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
# type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
864c718a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
...
@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata
,
get_scheduler_metadata
,
reshape_and_cache_flash
)
reshape_and_cache_flash
)
else
:
else
:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_varlen_func
,
from
vllm.attention.utils.fa_utils
import
(
vllm_flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
)
reshape_and_cache_flash
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
@@ -83,6 +82,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -83,6 +82,7 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
return
FlashAttentionMetadataBuilder
if
not
current_platform
.
is_rocm
():
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -106,6 +106,35 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -106,6 +106,35 @@ class FlashAttentionBackend(AttentionBackend):
else
:
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
return
stride_order
else
:
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
2
,
1
,
3
)
value_stride_order
=
(
0
,
3
,
1
,
2
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
@
dataclass
@
dataclass
...
@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
if
self
.
kv_sharing_target_layer_name
is
None
:
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
...
@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# actual tokens.
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
...
@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
else
:
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
...
@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl):
# k_descale=layer._k_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache
=
Fals
e
,
is_prefix_cache
=
Tru
e
,
)
)
return
output
return
output
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
864c718a
...
@@ -2494,6 +2494,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2494,6 +2494,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec
.
page_size_bytes
)
kv_cache_spec
.
page_size_bytes
)
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
has_attn
=
True
has_attn
=
True
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
not
kv_cache_spec
.
use_mla
:
key_cache_shape
,
value_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
try
:
key_stride_order
,
value_stride_order
=
self
.
attn_backends
[
i
].
get_kv_cache_stride_order
()
assert
len
(
key_stride_order
)
==
len
(
key_cache_shape
)
assert
len
(
value_stride_order
)
==
len
(
value_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
key_stride_order
=
tuple
(
range
(
len
(
key_cache_shape
)))
value_stride_order
=
tuple
(
range
(
len
(
value_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
key_cache_shape
=
tuple
(
key_cache_shape
[
i
]
for
i
in
key_stride_order
)
value_cache_shape
=
tuple
(
value_cache_shape
[
i
]
for
i
in
value_stride_order
)
# Maintain original KV shape view.
inv_key_order
=
[
key_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
key_stride_order
))
]
inv_value_order
=
[
value_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
value_stride_order
))
]
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
)
total_elements
=
raw_tensor
.
numel
()
key_elements
=
(
key_cache_shape
[
0
]
*
key_cache_shape
[
1
]
*
key_cache_shape
[
2
]
*
key_cache_shape
[
3
])
value_elements
=
(
value_cache_shape
[
0
]
*
value_cache_shape
[
1
]
*
value_cache_shape
[
2
]
*
value_cache_shape
[
3
])
assert
total_elements
==
key_elements
+
value_elements
key_cache
=
raw_tensor
[:
key_elements
].
view
(
key_cache_shape
).
permute
(
*
inv_key_order
)
value_cache
=
raw_tensor
[
key_elements
:].
view
(
value_cache_shape
).
permute
(
*
inv_value_order
)
kv_caches
[
layer_name
]
=
(
key_cache
,
value_cache
)
else
:
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment