Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e41f5ab
Unverified
Commit
2e41f5ab
authored
Sep 15, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Sep 15, 2025
Browse files
[XPU] Set consistent default KV cache layout (#24745)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
bc0f6059
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
16 deletions
+23
-16
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+9
-6
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+4
-6
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+10
-4
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
2e41f5ab
...
@@ -56,9 +56,9 @@ except ImportError:
...
@@ -56,9 +56,9 @@ except ImportError:
logger
.
warning
(
"NIXL is not available"
)
logger
.
warning
(
"NIXL is not available"
)
NixlWrapper
=
None
NixlWrapper
=
None
# Supported
xPU
s and types of kv transfer buffer.
# Supported
platform
s and types of kv transfer buffer.
# {
xPU
: tuple of supported kv buffer types}
# {
device
: tuple of supported kv buffer types}
_NIXL_SUPPORTED_
XPUS
=
{
_NIXL_SUPPORTED_
DEVICE
=
{
"cuda"
:
(
"cuda"
,
),
"cuda"
:
(
"cuda"
,
),
"tpu"
:
(
"cpu"
,
),
"tpu"
:
(
"cpu"
,
),
"xpu"
:
(
"cpu"
,
),
"xpu"
:
(
"cpu"
,
),
...
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
...
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
self
.
device_type
=
current_platform
.
device_type
self
.
device_type
=
current_platform
.
device_type
self
.
kv_buffer_device
:
str
=
\
self
.
kv_buffer_device
:
str
=
\
vllm_config
.
kv_transfer_config
.
kv_buffer_device
vllm_config
.
kv_transfer_config
.
kv_buffer_device
if
self
.
device_type
not
in
_NIXL_SUPPORTED_
XPUS
:
if
self
.
device_type
not
in
_NIXL_SUPPORTED_
DEVICE
:
raise
RuntimeError
(
f
"
{
self
.
device_type
}
is not supported."
)
raise
RuntimeError
(
f
"
{
self
.
device_type
}
is not supported."
)
elif
self
.
kv_buffer_device
not
in
_NIXL_SUPPORTED_
XPUS
[
elif
self
.
kv_buffer_device
not
in
_NIXL_SUPPORTED_
DEVICE
[
self
.
device_type
]:
self
.
device_type
]:
raise
RuntimeError
(
raise
RuntimeError
(
f
"
{
self
.
device_type
}
with
{
self
.
kv_buffer_device
}
kv_buffer "
f
"
{
self
.
device_type
}
with
{
self
.
kv_buffer_device
}
kv_buffer "
...
@@ -468,7 +468,7 @@ class NixlConnectorWorker:
...
@@ -468,7 +468,7 @@ class NixlConnectorWorker:
self
.
device_kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
device_kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
# cpu kv buffer for xfer
# cpu kv buffer for xfer
# used when
xPU
memory can not be registered under nixl
# used when
device
memory can not be registered under nixl
self
.
host_xfer_buffers
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
host_xfer_buffers
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
use_host_buffer
=
self
.
kv_buffer_device
==
"cpu"
self
.
use_host_buffer
=
self
.
kv_buffer_device
==
"cpu"
if
self
.
kv_buffer_device
==
"cuda"
:
if
self
.
kv_buffer_device
==
"cuda"
:
...
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
...
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
if
tp_ratio
>
1
:
if
tp_ratio
>
1
:
# Heterogeneous TP expects same kv_cache_layout.
# Heterogeneous TP expects same kv_cache_layout.
assert
nixl_agent_meta
.
kv_cache_layout
==
self
.
kv_cache_layout
assert
nixl_agent_meta
.
kv_cache_layout
==
self
.
kv_cache_layout
if
self
.
device_type
==
"xpu"
:
raise
ValueError
(
"Heterogeneous TP is not supported on XPU"
)
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
(
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
(
"Remote P worker KV layer cache must be of shape [2, N, "
"Remote P worker KV layer cache must be of shape [2, N, "
...
...
vllm/platforms/xpu.py
View file @
2e41f5ab
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
vllm.v1.attention.backends.utils
import
set_kv_cache_layout
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
...
@@ -164,12 +165,9 @@ class XPUPlatform(Platform):
...
@@ -164,12 +165,9 @@ class XPUPlatform(Platform):
vllm_config
.
scheduler_config
.
max_model_len
,
vllm_config
.
scheduler_config
.
max_model_len
,
DEFAULT_MAX_NUM_BATCHED_TOKENS
)
DEFAULT_MAX_NUM_BATCHED_TOKENS
)
if
(
envs
.
VLLM_KV_CACHE_LAYOUT
is
None
set_kv_cache_layout
(
"NHD"
)
or
envs
.
VLLM_KV_CACHE_LAYOUT
!=
"NHD"
):
logger
.
info
(
"Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
os
.
environ
[
"VLLM_KV_CACHE_LAYOUT"
]
=
"NHD"
"only NHD layout is supported by XPU attention kernels."
)
logger
.
info
(
"Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
"only NHD layout is supported by XPU attention kernels."
)
@
classmethod
@
classmethod
def
is_pin_memory_available
(
cls
):
def
is_pin_memory_available
(
cls
):
...
...
vllm/v1/attention/backends/utils.py
View file @
2e41f5ab
...
@@ -5,8 +5,8 @@ import enum
...
@@ -5,8 +5,8 @@ import enum
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
fields
,
make_dataclass
from
dataclasses
import
dataclass
,
fields
,
make_dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Optional
,
Protoco
l
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Literal
,
Optiona
l
,
TypeVar
)
Protocol
,
TypeVar
,
Union
,
get_args
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -30,7 +30,12 @@ from vllm.logger import init_logger
...
@@ -30,7 +30,12 @@ from vllm.logger import init_logger
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_KV_CACHE_LAYOUT_OVERRIDE
=
None
KVCacheLayoutType
=
Literal
[
"NHD"
,
"HND"
]
_KV_CACHE_LAYOUT_OVERRIDE
:
Union
[
KVCacheLayoutType
,
None
]
=
None
def
is_valid_kv_cache_layout
(
value
:
str
)
->
bool
:
return
value
in
get_args
(
KVCacheLayoutType
)
@
dataclass
@
dataclass
...
@@ -296,12 +301,13 @@ def get_kv_cache_layout():
...
@@ -296,12 +301,13 @@ def get_kv_cache_layout():
if
cache_layout
is
None
:
if
cache_layout
is
None
:
cache_layout
=
get_kv_connector_cache_layout
()
cache_layout
=
get_kv_connector_cache_layout
()
else
:
else
:
assert
is_valid_kv_cache_layout
(
cache_layout
)
logger
.
info_once
(
"`VLLM_KV_CACHE_LAYOUT` environment variable "
\
logger
.
info_once
(
"`VLLM_KV_CACHE_LAYOUT` environment variable "
\
"detected. Setting KV cache layout to %s."
,
cache_layout
)
"detected. Setting KV cache layout to %s."
,
cache_layout
)
return
cache_layout
return
cache_layout
def
set_kv_cache_layout
(
cache_layout
:
str
):
def
set_kv_cache_layout
(
cache_layout
:
KVCacheLayoutType
):
global
_KV_CACHE_LAYOUT_OVERRIDE
global
_KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE
=
cache_layout
_KV_CACHE_LAYOUT_OVERRIDE
=
cache_layout
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment