Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96b23b8e
Unverified
Commit
96b23b8e
authored
Nov 14, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Nov 14, 2025
Browse files
[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
433c0f86
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
17 deletions
+73
-17
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+4
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+57
-10
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+12
-5
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
96b23b8e
...
@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
...
@@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
req_index
=
0
req_index
=
0
block_table
.
append_row
(
kvcache_manager_blocks
,
req_index
)
block_table
.
append_row
(
kvcache_manager_blocks
,
req_index
)
# Get expected kernel blocks from the implementation for verification.
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks
=
block_table
.
_map_to_kernel_blocks
(
expected_kernel_blocks
=
block_table
.
map_to_kernel_blocks
(
np
.
array
(
kvcache_manager_blocks
)
np
.
array
(
kvcache_manager_blocks
),
block_table
.
blocks_per_kv_block
,
block_table
.
_kernel_block_arange
,
)
)
# Verify block table state
# Verify block table state
assert
block_table
.
num_blocks_per_row
[
req_index
]
==
len
(
expected_kernel_blocks
)
assert
block_table
.
num_blocks_per_row
[
req_index
]
==
len
(
expected_kernel_blocks
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
96b23b8e
...
@@ -49,6 +49,7 @@ from vllm.platforms import current_platform
...
@@ -49,6 +49,7 @@ from vllm.platforms import current_platform
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
...
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
@
dataclass
@
dataclass
class
ReqMeta
:
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
list
[
int
]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids
:
list
[
int
]
remote_block_ids
:
list
[
int
]
remote_block_ids
:
list
[
int
]
remote_host
:
str
remote_host
:
str
remote_port
:
int
remote_port
:
int
...
@@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
assert
load_remote_cache
^
save_to_host
assert
load_remote_cache
^
save_to_host
_req
=
ReqMeta
(
_req
=
ReqMeta
(
local_block_ids
=
local_block_ids
,
local_block_ids
=
local_block_ids
,
local_physical_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
...
@@ -935,6 +939,7 @@ class NixlConnectorWorker:
...
@@ -935,6 +939,7 @@ class NixlConnectorWorker:
attn_backend
=
backend
,
attn_backend
=
backend
,
)
)
self
.
_use_pallas
=
self
.
kv_topo
.
_use_pallas
self
.
_use_pallas
=
self
.
kv_topo
.
_use_pallas
self
.
_physical_blocks_per_logical_kv_block
=
1
def
_nixl_handshake
(
def
_nixl_handshake
(
self
,
self
,
...
@@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
...
@@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
if
base_addr
in
seen_base_addresses
:
if
base_addr
in
seen_base_addresses
:
continue
continue
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size
=
cache
.
shape
[
-
2
]
if
self
.
use_mla
else
cache
.
shape
[
-
3
]
if
self
.
block_size
!=
kernel_block_size
:
logger
.
info_once
(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. "
,
self
.
block_size
,
kernel_block_size
,
)
self
.
_physical_blocks_per_logical_kv_block
=
(
self
.
block_size
//
kernel_block_size
)
self
.
block_size
=
kernel_block_size
seen_base_addresses
.
append
(
base_addr
)
seen_base_addresses
.
append
(
base_addr
)
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
...
@@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
...
@@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
assert
self
.
use_host_buffer
assert
self
.
use_host_buffer
assert
self
.
copy_blocks
is
not
None
assert
self
.
copy_blocks
is
not
None
local_block_ids
=
meta
.
local_block_ids
local_block_ids
=
meta
.
local_
physical_
block_ids
self
.
copy_blocks
(
self
.
copy_blocks
(
self
.
host_xfer_buffers
,
self
.
host_xfer_buffers
,
self
.
device_kv_caches
,
self
.
device_kv_caches
,
...
@@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
...
@@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
"synced recved kv of request[%s] to device kv buffer,"
"synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. "
,
"local_block_ids: %s. "
,
req_id
,
req_id
,
","
.
join
(
map
(
str
,
meta
.
local_block_ids
)),
","
.
join
(
map
(
str
,
local_block_ids
)),
)
)
def
save_kv_to_host
(
self
,
metadata
:
NixlConnectorMetadata
):
def
save_kv_to_host
(
self
,
metadata
:
NixlConnectorMetadata
):
...
@@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
...
@@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
assert
self
.
copy_blocks
is
not
None
assert
self
.
copy_blocks
is
not
None
for
req_id
,
meta
in
metadata
.
reqs_to_save
.
items
():
for
req_id
,
meta
in
metadata
.
reqs_to_save
.
items
():
meta
.
local_physical_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
local_block_ids
)
if
logger
.
isEnabledFor
(
logging
.
DEBUG
):
if
logger
.
isEnabledFor
(
logging
.
DEBUG
):
logger
.
debug
(
logger
.
debug
(
"save_load_kv for request[%s] to host xfer buffer."
"save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. "
,
"local_block_ids: %s. "
,
req_id
,
req_id
,
","
.
join
(
map
(
str
,
meta
.
local_block_ids
)),
","
.
join
(
map
(
str
,
meta
.
local_
physical_
block_ids
)),
)
)
# blocking
# blocking
self
.
copy_blocks
(
self
.
copy_blocks
(
self
.
device_kv_caches
,
self
.
device_kv_caches
,
self
.
host_xfer_buffers
,
self
.
host_xfer_buffers
,
meta
.
local_block_ids
,
meta
.
local_
physical_
block_ids
,
meta
.
local_block_ids
,
meta
.
local_
physical_
block_ids
,
"d2h"
,
"d2h"
,
)
)
...
@@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
...
@@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
if
self
.
use_host_buffer
:
if
self
.
use_host_buffer
:
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
if
self
.
enable_permute_local_kv
:
if
self
.
enable_permute_local_kv
:
block_ids_to_permute
+=
meta
.
local_block_ids
block_ids_to_permute
+=
meta
.
local_
physical_
block_ids
if
len
(
block_ids_to_permute
)
>
0
:
if
len
(
block_ids_to_permute
)
>
0
:
self
.
permute_device_kv
(
block_ids_to_permute
)
self
.
permute_device_kv
(
block_ids_to_permute
)
...
@@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
...
@@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
req_id
,
req_id
,
xfer_state
,
xfer_state
,
)
)
# mark all blocks for this request as invalid
# mark all
(logical)
blocks for this request as invalid
if
meta
:
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
):
if
meta
:
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
...
@@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
...
@@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
We check for these trnxs to complete in each step().
We check for these trnxs to complete in each step().
"""
"""
for
req_id
,
meta
in
metadata
.
reqs_to_recv
.
items
():
for
req_id
,
meta
in
metadata
.
reqs_to_recv
.
items
():
meta
.
local_physical_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
local_block_ids
)
meta
.
remote_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
remote_block_ids
)
remote_engine_id
=
meta
.
remote_engine_id
remote_engine_id
=
meta
.
remote_engine_id
logger
.
debug
(
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. "
,
"Num local_block_ids: %s. Num remote_block_ids: %s. "
,
req_id
,
req_id
,
remote_engine_id
,
remote_engine_id
,
len
(
meta
.
local_block_ids
),
len
(
meta
.
local_
physical_
block_ids
),
len
(
meta
.
remote_block_ids
),
len
(
meta
.
remote_block_ids
),
)
)
# always store metadata for failure recovery
# always store metadata for failure recovery
...
@@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
...
@@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
self
.
_read_blocks
(
self
.
_read_blocks
(
request_id
=
req_id
,
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote_engine_id
,
dst_engine_id
=
meta
.
remote_engine_id
,
local_block_ids
=
meta
.
local_block_ids
,
local_block_ids
=
meta
.
local_
physical_
block_ids
,
remote_block_ids
=
meta
.
remote_block_ids
,
remote_block_ids
=
meta
.
remote_block_ids
,
)
)
...
@@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
...
@@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
"Marking blocks as invalid."
,
"Marking blocks as invalid."
,
request_id
,
request_id
,
)
)
# mark all blocks for this request as invalid
# mark all
(logical)
blocks for this request as invalid
if
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
):
if
meta
:
=
self
.
_recving_metadata
.
get
(
request_id
):
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
_invalid_block_ids
.
update
(
meta
.
local_block_ids
)
self
.
xfer_stats
.
record_failed_transfer
()
self
.
xfer_stats
.
record_failed_transfer
()
...
@@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
...
@@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
descs_ids
=
region_ids
*
num_blocks
+
block_ids
descs_ids
=
region_ids
*
num_blocks
+
block_ids
return
descs_ids
.
flatten
()
return
descs_ids
.
flatten
()
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
list
[
int
])
->
list
[
int
]:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
if
self
.
_physical_blocks_per_logical_kv_block
==
1
:
# Noop when physical and logical block sizes are the same
return
block_ids
block_ids_np
=
np
.
array
(
block_ids
)
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
1
,
-
1
)
return
BlockTable
.
map_to_kernel_blocks
(
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
).
tolist
()
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
):
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
):
"""
"""
Get the block length for one K/V element (K and V have the same size).
Get the block length for one K/V element (K and V have the same size).
...
...
vllm/v1/worker/block_table.py
View file @
96b23b8e
...
@@ -98,7 +98,9 @@ class BlockTable:
...
@@ -98,7 +98,9 @@ class BlockTable:
return
return
if
self
.
use_hybrid_blocks
:
if
self
.
use_hybrid_blocks
:
block_ids
=
self
.
_map_to_kernel_blocks
(
np
.
array
(
block_ids
))
block_ids
=
self
.
map_to_kernel_blocks
(
np
.
array
(
block_ids
),
self
.
blocks_per_kv_block
,
self
.
_kernel_block_arange
)
num_blocks
=
len
(
block_ids
)
num_blocks
=
len
(
block_ids
)
start
=
self
.
num_blocks_per_row
[
row_idx
]
start
=
self
.
num_blocks_per_row
[
row_idx
]
...
@@ -188,7 +190,12 @@ class BlockTable:
...
@@ -188,7 +190,12 @@ class BlockTable:
self
.
block_table
.
gpu
.
fill_
(
0
)
self
.
block_table
.
gpu
.
fill_
(
0
)
self
.
block_table
.
cpu
.
fill_
(
0
)
self
.
block_table
.
cpu
.
fill_
(
0
)
def
_map_to_kernel_blocks
(
self
,
kv_manager_block_ids
:
np
.
ndarray
)
->
np
.
ndarray
:
@
staticmethod
def
map_to_kernel_blocks
(
kv_manager_block_ids
:
np
.
ndarray
,
blocks_per_kv_block
:
int
,
kernel_block_arange
:
np
.
ndarray
,
)
->
np
.
ndarray
:
"""Convert kv_manager_block_id IDs to kernel block IDs.
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
Example:
...
@@ -203,12 +210,12 @@ class BlockTable:
...
@@ -203,12 +210,12 @@ class BlockTable:
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
"""
if
not
self
.
use_hybrid
_block
s
:
if
blocks_per_kv
_block
==
1
:
return
kv_manager_block_ids
return
kv_manager_block_ids
kernel_block_ids
=
(
kernel_block_ids
=
(
kv_manager_block_ids
.
reshape
(
-
1
,
1
)
*
self
.
blocks_per_kv_block
kv_manager_block_ids
.
reshape
(
-
1
,
1
)
*
blocks_per_kv_block
+
self
.
_
kernel_block_arange
+
kernel_block_arange
)
)
return
kernel_block_ids
.
reshape
(
-
1
)
return
kernel_block_ids
.
reshape
(
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment