Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
864c718a
Commit
864c718a
authored
Aug 19, 2025
by
zhuwenwen
Browse files
update v1 fa layout
parent
693d5ed4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
168 additions
and
71 deletions
+168
-71
vllm/attention/layer.py
vllm/attention/layer.py
+2
-2
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+2
-2
vllm/config.py
vllm/config.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+83
-39
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+80
-27
No files found.
vllm/attention/layer.py
View file @
864c718a
...
...
@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
kv_cache_dtype
=
"auto"
block_size
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
block_size
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
is_attention_free
=
False
calculate_kv_scales
=
False
if
num_kv_heads
is
None
:
...
...
@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module):
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
=
None
,
block_size
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
,
block_size
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
,
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
current_platform
.
is_rocm
():
...
...
vllm/attention/utils/fa_utils.py
View file @
864c718a
...
...
@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata
)
elif
current_platform
.
is_rocm
():
from
vllm
import
_custom_ops
as
ops
reshape_and_cache_
flash
=
ops
.
reshape_and_cache_
flash
from
flash_attn
import
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
reshape_and_cache_
cuda
=
ops
.
reshape_and_cache_
cuda
from
flash_attn
import
vllm_flash_attn_varlen_func
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
reshape_and_cache_flash
=
ops
.
reshape_and_cache_flash
...
...
vllm/config.py
View file @
864c718a
...
...
@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class
CacheConfig
:
"""Configuration for the KV cache."""
block_size
:
BlockSize
=
1
6
if
not
envs
.
VLLM_USE_FLASH_ATTN_PA
else
6
4
# type: ignore
block_size
:
BlockSize
=
6
4
if
envs
.
VLLM_USE_FLASH_ATTN_PA
or
envs
.
VLLM_USE_FLASH_MLA
else
1
6
# type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
864c718a
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
import
numpy
as
np
import
torch
...
...
@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata
,
reshape_and_cache_flash
)
else
:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
reshape_and_cache_flash
)
from
vllm.attention.utils.fa_utils
import
(
vllm_flash_attn_varlen_func
,
reshape_and_cache_cuda
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
...
@@ -83,30 +82,60 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
if
not
current_platform
.
is_rocm
():
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
else
:
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
(
num_blocks
,
num_kv_heads
,
head_size
,
block_size
),
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
key_stride_order
=
(
0
,
1
,
2
,
3
)
value_stride_order
=
(
0
,
1
,
3
,
2
)
elif
cache_layout
==
"HND"
:
key_stride_order
=
(
0
,
2
,
1
,
3
)
value_stride_order
=
(
0
,
3
,
1
,
2
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
@
dataclass
class
FlashAttentionMetadata
:
...
...
@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
...
...
@@ -522,16 +554,28 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
...
...
@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl):
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache
=
Fals
e
,
is_prefix_cache
=
Tru
e
,
)
return
output
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
864c718a
...
...
@@ -2494,33 +2494,86 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec
.
page_size_bytes
)
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
has_attn
=
True
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
try
:
kv_cache_stride_order
=
self
.
attn_backends
[
i
].
get_kv_cache_stride_order
()
assert
len
(
kv_cache_stride_order
)
==
len
(
kv_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape
=
tuple
(
kv_cache_shape
[
i
]
for
i
in
kv_cache_stride_order
)
# Maintain original KV shape view.
inv_order
=
[
kv_cache_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
kv_cache_stride_order
))
]
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
*
inv_order
)
if
envs
.
VLLM_USE_FLASH_ATTN_PA
and
not
kv_cache_spec
.
use_mla
:
key_cache_shape
,
value_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
try
:
key_stride_order
,
value_stride_order
=
self
.
attn_backends
[
i
].
get_kv_cache_stride_order
()
assert
len
(
key_stride_order
)
==
len
(
key_cache_shape
)
assert
len
(
value_stride_order
)
==
len
(
value_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
key_stride_order
=
tuple
(
range
(
len
(
key_cache_shape
)))
value_stride_order
=
tuple
(
range
(
len
(
value_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
key_cache_shape
=
tuple
(
key_cache_shape
[
i
]
for
i
in
key_stride_order
)
value_cache_shape
=
tuple
(
value_cache_shape
[
i
]
for
i
in
value_stride_order
)
# Maintain original KV shape view.
inv_key_order
=
[
key_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
key_stride_order
))
]
inv_value_order
=
[
value_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
value_stride_order
))
]
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
)
total_elements
=
raw_tensor
.
numel
()
key_elements
=
(
key_cache_shape
[
0
]
*
key_cache_shape
[
1
]
*
key_cache_shape
[
2
]
*
key_cache_shape
[
3
])
value_elements
=
(
value_cache_shape
[
0
]
*
value_cache_shape
[
1
]
*
value_cache_shape
[
2
]
*
value_cache_shape
[
3
])
assert
total_elements
==
key_elements
+
value_elements
key_cache
=
raw_tensor
[:
key_elements
].
view
(
key_cache_shape
).
permute
(
*
inv_key_order
)
value_cache
=
raw_tensor
[
key_elements
:].
view
(
value_cache_shape
).
permute
(
*
inv_value_order
)
kv_caches
[
layer_name
]
=
(
key_cache
,
value_cache
)
else
:
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
try
:
kv_cache_stride_order
=
self
.
attn_backends
[
i
].
get_kv_cache_stride_order
()
assert
len
(
kv_cache_stride_order
)
==
len
(
kv_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_shape
)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape
=
tuple
(
kv_cache_shape
[
i
]
for
i
in
kv_cache_stride_order
)
# Maintain original KV shape view.
inv_order
=
[
kv_cache_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
kv_cache_stride_order
))
]
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
*
inv_order
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
has_mamba
=
True
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment