Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c8f64fa
Unverified
Commit
4c8f64fa
authored
Jun 17, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 17, 2025
Browse files
[V1][Kernel] Flashinfer HND KV cache layout (#19280)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
93aee29f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
20 deletions
+64
-20
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-3
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+5
-4
vllm/envs.py
vllm/envs.py
+11
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-7
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+21
-6
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+21
-0
No files found.
vllm/attention/backends/flashinfer.py
View file @
4c8f64fa
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
dataclasses
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
...
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
envs
.
VLLM_KV_CACHE_LAYOUT
or
"NHD"
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
4c8f64fa
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
"""
"""
KV cache helper for store.
KV cache helper for store.
"""
"""
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:
...
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
kv_config
=
vllm_config
.
kv_transfer_config
kv_config
=
vllm_config
.
kv_transfer_config
if
vllm_config
.
model_config
is
None
:
if
vllm_config
.
model_config
is
None
or
kv_config
is
None
:
logger
.
warning
(
"Unable to detect current VLLM config. "
\
logger
.
warning
_once
(
"Unable to detect current VLLM config. "
\
"Defaulting to NHD kv cache layout."
)
"Defaulting to NHD kv cache layout."
)
else
:
else
:
use_mla
=
vllm_config
.
model_config
.
use_mla
use_mla
=
vllm_config
.
model_config
.
use_mla
if
not
use_mla
and
kv_config
.
kv_connector
==
"NixlConnector"
:
if
not
use_mla
and
kv_config
.
kv_connector
==
"NixlConnector"
:
logger
.
info
(
"NixlConnector detected. Setting KV cache "
\
logger
.
info
_once
(
"NixlConnector detected. Setting KV cache "
\
"layout to HND for better xfer performance."
)
"layout to HND for better xfer performance."
)
return
"HND"
return
"HND"
return
"NHD"
return
"NHD"
vllm/envs.py
View file @
4c8f64fa
...
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
...
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
:
int
=
1
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
:
int
=
1
VLLM_SLEEP_WHEN_IDLE
:
bool
=
False
VLLM_SLEEP_WHEN_IDLE
:
bool
=
False
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_KV_CACHE_LAYOUT
:
Optional
[
str
]
=
None
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# processes via zmq.
# processes via zmq.
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
:
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
,
"16"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_MQ_MAX_CHUNK_BYTES_MB"
,
"16"
)),
# KV Cache layout used throughout vllm.
# Some common values are:
# - NHD
# - HND
# Where N=num_blocks, H=num_heads and D=head_size. The default value will
# leave the layout choice to the backend. Mind that backends may only
# implement and support a subset of all possible layouts.
"VLLM_KV_CACHE_LAYOUT"
:
lambda
:
os
.
getenv
(
"VLLM_KV_CACHE_LAYOUT"
,
None
)
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
4c8f64fa
...
@@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
...
@@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
get_flash_attn_version
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
def
get_kv_cache_stride_order
()
->
tuple
[
int
,
...]:
# NOTE When running disaggregated PD with NIXL, HND layout is used for
# `stride_order` indicates the permutation that gets
# faster transfer. `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout
=
get_kv_
connector_
cache_layout
()
cache_layout
=
get_kv_cache_layout
()
if
cache_layout
==
"NHD"
:
if
cache_layout
==
"NHD"
:
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
elif
cache_layout
==
"HND"
:
elif
cache_layout
==
"HND"
:
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
stride_order
=
(
0
,
1
,
3
,
2
,
4
)
else
:
else
:
raise
ValueError
(
"Unknown cache layout format
%s."
,
cache_layout
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
stride_order
return
stride_order
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
4c8f64fa
...
@@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
...
@@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend):
...
@@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    """Return the dimension permutation for the configured KV cache layout.

    The permutation maps the shape reported by `get_kv_cache_shape` to the
    memory layout that is actually wanted.

    Returns:
        The identity order ``(0, 1, 2, 3, 4)`` for the "NHD" layout, or
        ``(0, 1, 3, 2, 4)`` (dimensions 2 and 3 swapped) for "HND".

    Raises:
        ValueError: If `get_kv_cache_layout()` returns any other value.
    """
    layout = get_kv_cache_layout()
    if layout == "NHD":
        # Identity permutation: the shape order is kept as-is.
        return (0, 1, 2, 3, 4)
    if layout == "HND":
        # Dimensions 2 and 3 trade places relative to the NHD order.
        return (0, 1, 3, 2, 4)
    raise ValueError(f"Unknown cache layout format {layout}.")
@
dataclass
@
dataclass
class
PerLayerParameters
:
class
PerLayerParameters
:
...
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
_get_prefill_wrapper
(
self
):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
def
_get_decode_wrapper
(
self
):
...
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_qo_heads
//
num_kv_heads
>
4
)
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
self
.
_get_workspace_buffer
(),
"NHD"
,
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
return
self
.
_decode_wrapper
def
_get_cascade_wrapper
(
self
):
def
_get_cascade_wrapper
(
self
):
if
self
.
_cascade_wrapper
is
None
:
if
self
.
_cascade_wrapper
is
None
:
self
.
_cascade_wrapper
=
MultiLevelCascadeAttentionWrapper
(
self
.
_cascade_wrapper
=
MultiLevelCascadeAttentionWrapper
(
2
,
self
.
_get_workspace_buffer
(),
"NHD"
)
2
,
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
()
)
return
self
.
_cascade_wrapper
return
self
.
_cascade_wrapper
def
_plan
(
self
,
attn_metadata
:
FlashInferMetadata
):
def
_plan
(
self
,
attn_metadata
:
FlashInferMetadata
):
...
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
# Regular attention (common case).
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# according to reorder_batch()
...
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
assert
prefill_wrapper
.
_sm_scale
==
self
.
scale
assert
prefill_wrapper
.
_sm_scale
==
self
.
scale
prefill_wrapper
.
run
(
prefill_wrapper
.
run
(
prefill_query
,
prefill_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
out
=
output
[
num_decode_tokens
:],
out
=
output
[
num_decode_tokens
:],
...
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
decode_wrapper
.
run
(
decode_wrapper
.
run
(
decode_query
,
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
out
=
output
[:
num_decode_tokens
],
out
=
output
[:
num_decode_tokens
],
...
...
vllm/v1/attention/backends/utils.py
View file @
4c8f64fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
abc
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
...
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
...
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
import
vllm.envs
as
envs
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
@
dataclass
class
CommonAttentionMetadata
:
class
CommonAttentionMetadata
:
...
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
...
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
raise
ValueError
(
raise
ValueError
(
error_msg
+
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
f
"must be the same type as the current layer (
{
expected
}
)."
)
@functools.lru_cache
def get_kv_cache_layout():
    """Return the KV cache layout string to use (e.g. "NHD" or "HND").

    The value of the `VLLM_KV_CACHE_LAYOUT` environment variable (read via
    `envs.VLLM_KV_CACHE_LAYOUT`) takes precedence when it is set. Otherwise
    the choice is delegated to `get_kv_connector_cache_layout()`.

    The result is memoized by `functools.lru_cache`, so the environment
    variable and the connector are only consulted on the first call.

    Returns:
        The layout string, exactly as supplied by the environment variable
        or by `get_kv_connector_cache_layout()`; no validation is done here.
    """
    # Override with format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        # Name the variable that is actually read above, so the log message
        # points users at the setting that took effect.
        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                         "detected. Setting KV cache layout to %s.",
                         cache_layout)
    return cache_layout
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment