Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94578127
Unverified
Commit
94578127
authored
Jan 09, 2026
by
Chendi.Xue
Committed by
GitHub
Jan 09, 2026
Browse files
[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout (#30275)
parent
2612ba92
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
139 additions
and
87 deletions
+139
-87
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+78
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+61
-87
No files found.
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
94578127
...
@@ -203,6 +203,84 @@ def copy_kv_blocks(
...
@@ -203,6 +203,84 @@ def copy_kv_blocks(
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """
    Regroup received KV blocks from the remote block size into the local one.

    Only meaningful when the local block size is larger than the remote one:
    each local block then holds ``block_size_ratio`` consecutive remote blocks.

    Example with local block size 16 and remote block size 4, where local
    block[0] was filled from remote blocks [0, 1, 2, 3]:

        received: |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
        wanted:   |h0-b0..................|h1-b0..................|...

    Steps:
        1. view the data as n_blocks * remote_shape (H, remoteN, D)
        2. swap axes to (H, n_blocks, remoteN, D)
        3. merge to (H, localN, D)

    Args:
        cache: KV cache tensor, updated in place. Presumably indexed as
            (num_blocks, block_size, n_kv_heads, head_size) -- TODO confirm
            against the caller.
        indices: 1-D index tensor selecting the blocks (dim 0) to rewrite.
        block_size_ratio: local block size divided by remote block size.

    Returns:
        None. The selected blocks of ``cache`` are overwritten.
    """
    swap_tokens_and_heads = (0, 2, 1, 3)

    # Gather the touched blocks and expose them in physical (heads-first)
    # order so the reshape below walks the data the way it was received.
    received = cache.index_select(0, indices)
    received = received.permute(*swap_tokens_and_heads)

    heads, local_tokens, head_dim = received.shape[1:]
    remote_tokens = local_tokens // block_size_ratio

    # (B, n_blocks, H, remoteN, D): one slice per remote block.
    per_remote_block = received.reshape(
        -1, block_size_ratio, heads, remote_tokens, head_dim
    )
    # (B, H, n_blocks, remoteN, D): bring every head's chunks together.
    heads_outermost = per_remote_block.transpose(1, 2)
    # (B, H, localN, D): concatenate the chunks along the token axis.
    merged = heads_outermost.flatten(2, 3)

    # Return to the cache's own axis order before writing back.
    cache.index_copy_(0, indices, merged.permute(*swap_tokens_and_heads))
def kv_postprocess_layout_on_receive(cache, indices):
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Args:
        cache: KV cache tensor shaped like the target layout; updated in place.
        indices: 1-D index tensor selecting the blocks (dim 0) to rewrite.

    Returns:
        None. The selected blocks of ``cache`` are overwritten.

    Implementation:
    - x = blocks_to_update.reshape(src_shape)  # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order)  # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks)  # copy permuted kv back
    """
    # Swapping axes 1 and 2 is its own inverse, so the same order maps the
    # target shape to the sender shape and the sender data back to the target.
    inv_order = [0, 2, 1, 3]

    # Gather once: the same tensor supplies both the shape and the data.
    blocks_to_update = cache.index_select(0, indices)

    target_shape = list(blocks_to_update.shape)
    target_shape[0] = -1  # let reshape infer the number of selected blocks
    src_shape = tuple(target_shape[i] for i in inv_order)

    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)
def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """
    Convert received KV blocks to the local block size and from HND to NHD.

    Only meaningful when the local block size is larger than the remote one.
    The sender (prefill) uses HND with the smaller block size; the receiver
    (decode, local) uses NHD with the larger block size.

    Args:
        cache: KV cache tensor, updated in place. Presumably indexed as
            (num_blocks, block_size, n_kv_heads, head_size) -- TODO confirm
            against the caller.
        indices: 1-D index tensor selecting the blocks (dim 0) to rewrite.
        block_size_ratio: local block size divided by remote block size.

    Returns:
        None. The selected blocks of ``cache`` are overwritten.
    """
    received = cache.index_select(0, indices)
    local_tokens, heads, head_dim = received.shape[1:]
    remote_tokens = local_tokens // block_size_ratio

    # Reinterpret each local block as block_size_ratio remote blocks, each
    # laid out heads-first: (B, n_blocks, H, remoteN, D).
    per_remote_block = received.reshape(
        -1, block_size_ratio, heads, remote_tokens, head_dim
    )
    # Tokens before heads inside every remote block: (B, n_blocks, remoteN, H, D).
    tokens_first = per_remote_block.transpose(2, 3)
    # Merge the remote blocks along the token axis: (B, localN, H, D).
    merged = tokens_first.flatten(1, 2)

    cache.index_copy_(0, indices, merged)
def
yield_req_data
(
def
yield_req_data
(
scheduler_output
,
scheduler_output
,
)
->
Iterator
[
tuple
[
str
,
tuple
[
list
[
int
],
...],
bool
]]:
)
->
Iterator
[
tuple
[
str
,
tuple
[
list
[
int
],
...],
bool
]]:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
94578127
...
@@ -24,6 +24,9 @@ from vllm.config import VllmConfig
...
@@ -24,6 +24,9 @@ from vllm.config import VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
EngineId
,
EngineId
,
TpKVTopology
,
TpKVTopology
,
kv_postprocess_blksize_and_layout_on_receive
,
kv_postprocess_blksize_on_receive
,
kv_postprocess_layout_on_receive
,
yield_req_data
,
yield_req_data
,
)
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
...
@@ -1749,88 +1752,62 @@ class NixlConnectorWorker:
...
@@ -1749,88 +1752,62 @@ class NixlConnectorWorker:
"d2h"
,
"d2h"
,
)
)
def
permute_device_kv
(
self
,
block_ids
:
list
[
int
]):
def
post_process_device_kv_on_receive
(
"""Transforms the layout of received KV cache blocks to the local format.
self
,
block_size_ratio
:
int
,
This method corrects layout mismatches from direct memory copies by
block_ids_list
:
list
[
list
[
int
]],
permuting the tensor dimensions.
):
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
Args:
block_ids: A list of block IDs to update and permute.
Implementation:
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
"""
"""
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
Post process device kv cache after receiving from remote.
inv_order
=
[
0
,
2
,
1
,
3
]
sample_cache
=
list
(
self
.
device_kv_caches
.
values
())[
0
][
0
]
target_shape
=
list
(
sample_cache
.
shape
)
target_shape
[
0
]
=
-
1
src_shape
=
tuple
(
target_shape
[
i
]
for
i
in
inv_order
)
indices
=
torch
.
tensor
(
block_ids
,
device
=
sample_cache
.
device
)
for
_
,
cache_or_caches
in
self
.
device_kv_caches
.
items
():
3 types of post processing supported:
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
* kv_cache_postprocess_layout => convert from HND to NHD
for
cache
in
cache_list
:
* kv_cache_postprocess_blksize => convert from small block size
blocks_to_update
=
cache
.
index_select
(
0
,
indices
)
to large block size
permuted_blocks
=
blocks_to_update
.
reshape
(
src_shape
).
permute
(
* kv_cache_postprocess_blksize_and_layout => convert from small
*
inv_order
block size to large block size and convert from HND to NHD
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
int
,
list
[
list
[
int
]]]):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
remote_block_size
=
block_size
//
block_size_ratio
n_blocks
=
block_size_ratio
# actual permute is to convert
# for local blocksize > remote blocksize
# ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
# local block[0] = remote block[0, 1, 2, 3]
# remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
# local is |h0-b0..................|h1-b0..................|...
# permute is to:
# 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
# 2. permute => (H, nblocks, remoteN, D)
# 3. flatten => (H, localN, D)
permuted_blocks
=
(
blocks_to_update
.
reshape
(
-
1
,
n_blocks
,
n_kv_heads
,
remote_block_size
,
head_size
)
.
permute
(
0
,
2
,
1
,
3
,
4
)
.
flatten
(
2
,
3
)
)
return
permuted_blocks
"""
if
len
(
self
.
device_kv_caches
)
==
0
:
if
len
(
self
.
device_kv_caches
)
==
0
:
return
return
assert
block_size_ratio
>=
1
,
"Only nP < nD supported currently."
if
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
:
logger
.
debug
(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger and permuting layout from HND"
" to NHD."
,
block_size_ratio
,
)
elif
self
.
enable_permute_local_kv
:
logger
.
debug
(
"Post-processing device kv cache on receive by permuting layout"
"from HND to NHD."
)
else
:
logger
.
debug
(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger."
,
block_size_ratio
,
)
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
kv_topo
.
is_kv_layout_blocks_first
)
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
kv_topo
.
is_kv_layout_blocks_first
)
sample_cache
=
list
(
self
.
device_kv_caches
.
values
())[
0
][
0
]
for
block_size_ratio
,
block_ids_list
in
block_ids_per_ratio
.
items
():
assert
block_size_ratio
>
1
,
"Only nP < nD supported currently."
block_ids_list
=
[[
item
for
sublist
in
block_ids_list
for
item
in
sublist
]]
for
block_ids
in
block_ids_list
:
for
block_ids
in
block_ids_list
:
indices
=
torch
.
tensor
(
block_ids
,
device
=
s
ample_cache
.
device
)
indices
=
torch
.
tensor
(
block_ids
,
device
=
s
elf
.
device_type
,
dtype
=
torch
.
long
)
for
_
,
cache_or_caches
in
self
.
device_kv_caches
.
items
():
for
_
,
cache_or_caches
in
self
.
device_kv_caches
.
items
():
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
for
cache
in
cache_list
:
for
cache
in
cache_list
:
blocks_to_update
=
cache
.
index_select
(
0
,
indices
)
if
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
:
# because kv_cache is always using original layout NHD as
kv_postprocess_blksize_and_layout_on_receive
(
# virtual shape while stride can be either HND / NHD at
cache
,
indices
,
block_size_ratio
# initialization.
)
# we need to firstly get physical view of the tensor
elif
self
.
enable_permute_local_kv
:
permuted_blocks
=
_process_local_gt_remote
(
kv_postprocess_layout_on_receive
(
cache
,
indices
)
blocks_to_update
.
permute
(
0
,
2
,
1
,
3
),
block_size_ratio
else
:
).
permute
(
0
,
2
,
1
,
3
)
kv_postprocess_blksize_on_receive
(
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
cache
,
indices
,
block_size_ratio
)
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
"""
"""
...
@@ -1854,7 +1831,6 @@ class NixlConnectorWorker:
...
@@ -1854,7 +1831,6 @@ class NixlConnectorWorker:
len
(
done_recving
),
len
(
done_recving
),
)
)
block_ids_to_permute
=
[]
block_ids_for_blocksize_post_process
=
defaultdict
(
list
)
block_ids_for_blocksize_post_process
=
defaultdict
(
list
)
for
req_id
in
done_recving
:
for
req_id
in
done_recving
:
# clean up metadata for completed requests
# clean up metadata for completed requests
...
@@ -1863,24 +1839,22 @@ class NixlConnectorWorker:
...
@@ -1863,24 +1839,22 @@ class NixlConnectorWorker:
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
if
self
.
use_host_buffer
:
if
self
.
use_host_buffer
:
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
if
self
.
enable_permute_local_kv
:
block_ids_to_permute
+=
meta
.
local_physical_block_ids
# post processing for heteroblocksize
# post processing for heteroblocksize
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
meta
.
remote
.
engine_id
meta
.
remote
.
engine_id
)
)
if
(
if
not
self
.
use_mla
and
(
not
self
.
use_mla
block_size_ratio
>
1
or
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
and
self
.
kv_cache_layout
==
"HND"
):
):
block_ids_for_blocksize_post_process
[
block_size_ratio
].
append
(
block_ids_for_blocksize_post_process
[
block_size_ratio
].
append
(
meta
.
local_block_ids
meta
.
local_
physical_
block_ids
)
)
self
.
blocksize_post_process
(
block_ids_for_blocksize_post_process
)
for
(
if
len
(
block_ids_to_permute
)
>
0
:
block_size_ratio
,
self
.
permute_device_kv
(
block_ids_to_permute
)
block_ids_list
,
)
in
block_ids_for_blocksize_post_process
.
items
():
self
.
post_process_device_kv_on_receive
(
block_size_ratio
,
block_ids_list
)
# Handle timeout to avoid stranding blocks on remote.
# Handle timeout to avoid stranding blocks on remote.
now
=
time
.
perf_counter
()
now
=
time
.
perf_counter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment