Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c8f64fa
Unverified
Commit
4c8f64fa
authored
Jun 17, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 17, 2025
Browse files
[V1][Kernel] Flashinfer HND KV cache layout (#19280)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
93aee29f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
20 deletions
+64
-20
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-3
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+5
-4
vllm/envs.py
vllm/envs.py
+11
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-7
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+21
-6
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+21
-0
No files found.
vllm/attention/backends/flashinfer.py
View file @
4c8f64fa
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
os
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
...
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
envs
.
VLLM_KV_CACHE_LAYOUT
or
"NHD"
class
FlashInferBackend
(
AttentionBackend
):
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
4c8f64fa
...
...
@@ -3,7 +3,6 @@
"""
KV cache helper for store.
"""
import
torch
import
vllm.envs
as
envs
...
...
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
vllm_config
=
get_current_vllm_config
()
kv_config
=
vllm_config
.
kv_transfer_config
if
vllm_config
.
model_config
is
None
:
logger
.
warning
(
"Unable to detect current VLLM config. "
\
if
vllm_config
.
model_config
is
None
or
kv_config
is
None
:
logger
.
warning
_once
(
"Unable to detect current VLLM config. "
\
"Defaulting to NHD kv cache layout."
)
else
:
use_mla
=
vllm_config
.
model_config
.
use_mla
if
not
use_mla
and
kv_config
.
kv_connector
==
"NixlConnector"
:
logger
.
info
(
"NixlConnector detected. Setting KV cache "
\
logger
.
info
_once
(
"NixlConnector detected. Setting KV cache "
\
"layout to HND for better xfer performance."
)
return
"HND"
return
"NHD"
vllm/envs.py
View file @
4c8f64fa
...
...
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
:
int
=
1
VLLM_SLEEP_WHEN_IDLE
:
bool
=
False
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_KV_CACHE_LAYOUT
:
Optional
[
str
]
=
None
def
get_default_cache_root
():
...
...
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# processes via zmq.
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
,
"16"
)),
# KV Cache layout used throughout vllm.
# Some common values are:
# - NHD
# - HND
# Where N=num_blocks, H=num_heads and D=head_size. The default value will
# leave the layout choice to the backend. Mind that backends may only
# implement and support a subset of all possible layouts.
"VLLM_KV_CACHE_LAYOUT"
:
lambda
:
os
.
getenv
(
"VLLM_KV_CACHE_LAYOUT"
,
None
)
}
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
4c8f64fa
...
...
@@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# NOTE When running disaggregated PD with NIXL, HND layout is used for
# faster transfer. `stride_order` indicates the permutation that gets
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_
connector_
cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
"Unknown cache layout format
%s."
,
cache_layout
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
4c8f64fa
...
...
@@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# `stride_order` indicates the permutation that gets us from
# `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
@
dataclass
class
PerLayerParameters
:
...
...
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
...
...
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
,
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
def
_get_cascade_wrapper
(
self
):
if
self
.
_cascade_wrapper
is
None
:
self
.
_cascade_wrapper
=
MultiLevelCascadeAttentionWrapper
(
2
,
self
.
_get_workspace_buffer
(),
"NHD"
)
2
,
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
()
)
return
self
.
_cascade_wrapper
def
_plan
(
self
,
attn_metadata
:
FlashInferMetadata
):
...
...
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
...
...
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
assert
prefill_wrapper
.
_sm_scale
==
self
.
scale
prefill_wrapper
.
run
(
prefill_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
out
=
output
[
num_decode_tokens
:],
...
...
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
decode_wrapper
.
run
(
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
out
=
output
[:
num_decode_tokens
],
...
...
vllm/v1/attention/backends/utils.py
View file @
4c8f64fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
...
...
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
import
vllm.envs
as
envs
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
CommonAttentionMetadata
:
...
...
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
raise
ValueError
(
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
@
functools
.
lru_cache
def
get_kv_cache_layout
():
# Override with format specified by the user.
cache_layout
=
envs
.
VLLM_KV_CACHE_LAYOUT
if
cache_layout
is
None
:
cache_layout
=
get_kv_connector_cache_layout
()
else
:
logger
.
info_once
(
"`FLASHINFER_KV_CACHE_LAYOUT` environment variable "
\
"detected. Setting KV cache layout to %s."
,
cache_layout
)
return
cache_layout
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment