Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5c081d4
Unverified
Commit
f5c081d4
authored
Mar 16, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 16, 2026
Browse files
[PD][Nixl] Add support for hybrid SSM-FA models (#36687)
parent
c88ea833
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
587 additions
and
166 deletions
+587
-166
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+8
-0
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+76
-45
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+112
-0
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+38
-8
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
...v_transfer/kv_connector/v1/mooncake/mooncake_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+351
-112
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
f5c081d4
...
...
@@ -18,11 +18,19 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
# Select config array based on DP_EP env var
if
[[
-n
"
${
DP_EP
:-}
"
]]
;
then
configs
=(
"
${
dp_ep_configs
[@]
}
"
)
echo
"DP_EP is set, using dp_ep_configs"
elif
[[
-n
"
${
HYBRID_SSM
:-}
"
]]
;
then
configs
=(
"
${
hybrid_ssm_configs
[@]
}
"
)
echo
"HYBRID_SSM is set, using hybrid_ssm_configs."
else
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
f5c081d4
...
...
@@ -18,6 +18,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
:
0.84
,
}
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
f5c081d4
...
...
@@ -53,7 +53,13 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from
vllm.v1.attention.backends.utils
import
set_kv_cache_layout
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
,
KVCacheTensor
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
...
...
@@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer0"
,
"layer1"
,
"layer2"
],
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
),
)
]
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
2
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
...
...
@@ -437,7 +455,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
test_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
test_shape
=
self
.
attn_backend
s
[
0
]
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
self
.
kv_topo
=
TpKVTopology
(
...
...
@@ -447,7 +465,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
remote_block_size
=
self
.
_block_size
,
# shared state
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
self
.
attn_backend
,
attn_backend
s
=
self
.
attn_backend
s
,
tensor_shape
=
test_shape
,
)
...
...
@@ -501,6 +519,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
ssm_sizes
=
(
0
,
0
),
),
remote_tp_rank
=
remote_tp_rank
,
remote_tp_size
=
remote_tp_size
,
...
...
@@ -951,6 +970,7 @@ class TestNixlHandshake:
block_lens
=
worker
.
block_len_per_layer
,
kv_cache_layout
=
mismatched_layout
,
block_size
=
worker
.
block_size
,
ssm_sizes
=
(
0
,
0
),
)
with
pytest
.
raises
(
RuntimeError
):
...
...
@@ -1006,6 +1026,7 @@ class TestNixlHandshake:
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
kv_cache_layout
=
"HND"
,
block_size
=
worker
.
block_size
,
ssm_sizes
=
(
0
,
0
),
)
# We don't check layout for homogeneous TP and MLA for now, as the
...
...
@@ -1496,9 +1517,47 @@ def test_register_kv_caches(
# test run if not mocking.
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backends
.
return_value
=
[
backend_cls
]
num_layers
=
32
block_size
=
16
num_blocks
=
8
num_heads
=
4
head_size
=
16
# TODO (NickLucche) the fact that connector depends on kv_cache_config for init
# but cross-layer preference cant be inferred prior to creating kv_cache_config
# is a bit awkward.
dummy_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
block_size
),
)
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_heads
,
head_size
=
head_size
,
dtype
=
torch
.
float16
,
)
if
dummy_connector
.
prefer_cross_layer_blocks
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
kv_cache_spec
.
page_size_bytes
*
num_blocks
,
shared_by
=
[
"all-layers"
],
)
for
_
in
range
(
num_layers
)
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"all-layers"
],
kv_cache_spec
)],
)
else
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer0"
,
"layer1"
,
"layer2"
],
kv_cache_spec
)
],
)
# Create connector
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
...
...
@@ -1526,35 +1585,6 @@ def test_register_kv_caches(
or
connector
.
prefer_cross_layer_blocks
)
if
connector
.
prefer_cross_layer_blocks
:
num_layers
=
32
block_size
=
16
num_blocks
=
8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
num_blocks
=
num_blocks
)
connector
.
connector_worker
.
kv_cache_config
=
worker_kv_cache_config
connector
.
connector_worker
.
num_blocks
=
worker_kv_cache_config
.
num_blocks
kv_cache_spec
=
AttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
64
,
dtype
=
torch
.
bfloat16
,
)
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
kv_cache_spec
.
page_size_bytes
*
num_blocks
,
shared_by
=
[
"dummy-layer"
],
)
for
i
in
range
(
num_layers
)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups
=
[],
)
with
set_current_vllm_config
(
vllm_config
):
_
,
cross_layers_kv_cache
,
_
=
(
KVConnectorModelRunnerMixin
.
allocate_uniform_kv_caches
(
...
...
@@ -1586,12 +1616,8 @@ def test_register_kv_caches(
expected_blocks_count
=
8
kv_caches
=
{
"all-layers"
:
cross_layers_kv_cache
}
else
:
# Create test kv cache tensors using proper backend shape
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
...
...
@@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation(
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
decode_worker
.
attn_backend
s
[
0
]
.
get_kv_cache_shape
(
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
...
...
@@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation(
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
# Build kv_caches from the actual layer names in kv_cache_config so that
# _layer_specs lookups in register_kv_caches always find a matching key.
layer_names
=
[
name
for
group
in
kv_cache_config
.
kv_cache_groups
for
name
in
group
.
layer_names
]
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
name
:
shared_tensor
if
i
%
2
==
0
else
unique_tensor
for
i
,
name
in
enumerate
(
layer_names
)
}
decode_connector
.
register_kv_caches
(
kv_caches
)
...
...
@@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
block_lens
=
[
4096
*
prefill_block_size
],
# slot_size * block_size
kv_cache_layout
=
"HND"
,
block_size
=
prefill_block_size
,
ssm_sizes
=
(
0
,
0
),
)
handshake_payload
=
NixlHandshakePayload
(
compatibility_hash
=
remote_hash
,
...
...
@@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
remote_block_size
=
decode_worker
.
_block_size
,
# shared state
is_mla
=
decode_worker
.
use_mla
,
total_num_kv_heads
=
decode_worker
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
backend
,
attn_backend
s
=
[
backend
]
,
tensor_shape
=
test_shape
,
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
f5c081d4
...
...
@@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker
.
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
True
)
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
...
...
@@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
@
pytest
.
mark
.
cpu_test
def
test_get_block_descs_ids_hybrid_ssm
():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch)."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
num_blocks
=
100
engine_id
=
"test-engine"
worker
.
num_regions
=
2
worker
.
dst_num_blocks
=
{
engine_id
:
num_blocks
}
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
1
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker
.
num_descs
=
2
*
num_blocks
fa_blocks
=
[
3
,
5
]
ssm_blocks
=
[
1
,
2
]
result
=
worker
.
_get_block_descs_ids
(
engine_id
,
(
fa_blocks
,
ssm_blocks
))
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
@
pytest
.
mark
.
cpu_test
def
test_get_block_descs_ids_kernel_block_mismatch
():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
ratio
=
4
logical_blocks
=
100
num_blocks
=
logical_blocks
*
ratio
# 400 kernel blocks
engine_id
=
"test-engine"
worker
.
num_regions
=
2
worker
.
dst_num_blocks
=
{
engine_id
:
num_blocks
}
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
ratio
worker
.
num_descs
=
2
*
num_blocks
# 800
fa_blocks
=
[
3
,
7
]
# kernel-level block IDs
ssm_blocks
=
[
1
,
2
]
# logical block IDs
result
=
worker
.
_get_block_descs_ids
(
engine_id
,
(
fa_blocks
,
ssm_blocks
))
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
@
pytest
.
mark
.
cpu_test
def
test_nixl_metadata_hybrid_ssm_block_ids
():
"""Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
groups with different block counts (kernel mismatch active)."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
,
)
metadata
=
NixlConnectorMetadata
()
# FA: 8 kernel blocks (2 logical * ratio=4), SSM: 2 logical blocks
fa_blocks
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
ssm_blocks
=
[
0
,
1
]
metadata
.
add_new_req_to_recv
(
request_id
=
"test-req-hybrid"
,
local_block_ids
=
(
fa_blocks
,
ssm_blocks
),
kv_transfer_params
=
{
"remote_block_ids"
:
([
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
],
[
20
,
21
]),
"remote_engine_id"
:
"remote-engine"
,
"remote_request_id"
:
"prefill-test-req-hybrid"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"tp_size"
:
1
,
},
)
assert
"test-req-hybrid"
in
metadata
.
reqs_to_recv
req_meta
=
metadata
.
reqs_to_recv
[
"test-req-hybrid"
]
# Verify local block IDs: different lengths per group
assert
len
(
req_meta
.
local_block_ids
)
==
2
assert
list
(
req_meta
.
local_block_ids
[
0
])
==
fa_blocks
assert
list
(
req_meta
.
local_block_ids
[
1
])
==
ssm_blocks
assert
len
(
req_meta
.
local_block_ids
[
0
])
!=
len
(
req_meta
.
local_block_ids
[
1
])
# Verify remote block IDs: same asymmetry preserved
assert
req_meta
.
remote
is
not
None
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
20
,
21
]
assert
len
(
req_meta
.
remote
.
block_ids
[
0
])
!=
len
(
req_meta
.
remote
.
block_ids
[
1
])
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
f5c081d4
...
...
@@ -16,10 +16,12 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_cache_interface
import
MambaSpec
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
logger
=
init_logger
(
__name__
)
...
...
@@ -328,22 +330,26 @@ class TpKVTopology:
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
attn_backend
s
:
list
[
type
[
AttentionBackend
]
]
engine_id
:
EngineId
remote_block_size
:
dict
[
EngineId
,
int
]
tensor_shape
:
torch
.
Size
|
None
=
None
is_mamba
:
bool
=
False
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
_MOCK_BLOCK_SIZE
=
16
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
_MOCK_BLOCK_SIZE
,
num_kv_heads
=
1
,
head_size
=
1
)
logger
.
debug
(
"Test kv_cache_shape: %s"
,
kv_cache_shape
)
attn_backend
=
self
.
attn_backends
[
0
]
if
not
self
.
is_mamba
:
_MOCK_BLOCK_SIZE
=
16
kv_cache_shape
:
tuple
[
int
,
...]
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
_MOCK_BLOCK_SIZE
,
num_kv_heads
=
1
,
head_size
=
1
)
logger
.
debug
(
"Test kv_cache_shape: %s"
,
kv_cache_shape
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self
.
_is_kv_layout_blocks_first
=
(
# Hybrid SSM models assume a single blocks_first layout
self
.
_is_kv_layout_blocks_first
=
self
.
is_mamba
or
(
len
(
kv_cache_shape
)
==
5
and
kv_cache_shape
[
0
]
==
1
)
...
...
@@ -360,7 +366,7 @@ class TpKVTopology:
_MOCK_NUM_LAYERS
=
80
kv_cache_shape
=
(
_MOCK_NUM_LAYERS
,)
+
kv_cache_shape
try
:
kv_cache_stride_order
=
self
.
attn_backend
.
get_kv_cache_stride_order
(
kv_cache_stride_order
=
attn_backend
.
get_kv_cache_stride_order
(
include_num_layers_dimension
=
self
.
_cross_layers_blocks
)
except
(
AttributeError
,
NotImplementedError
):
...
...
@@ -483,6 +489,30 @@ class TpKVTopology:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_ranks
(
remote_tp_size
)
def
get_transfer_cache_regions
(
self
,
cache
:
torch
.
Tensor
,
layer_spec
:
"KVCacheSpec"
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
"""Return the cache tensor(s) to register as NIXL memory regions,
also accounting for hybrid SSM models specificities.
"""
if
isinstance
(
layer_spec
,
MambaSpec
):
# Register the whole kv cache shared tensor, including SSM/Conv. This is
# similar to FI with the difference that SSM/Conv have different sizes
conv
,
ssm
=
cache
return
[
conv
]
# Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`.
if
self
.
is_mamba
and
cache
.
shape
[
0
]
==
2
:
# When MAMBA is present, all backends are blocks first, so that blocks
# can be shared between attention layers and mamba layers. Runner
# `_update_hybrid_attention_mamba_layout` already adjusted strides
# for FlashAttn-like backends so its num_blocks first.
# Swap [2<>num_blocks] dims to get required layout for hybrid SSM.
cache
=
cache
.
transpose
(
0
,
1
)
# Regular case: backends like FA register K/V in separate regions
return
cache
if
self
.
split_k_and_v
else
[
cache
]
def
get_current_attn_backends
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
...
...
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
View file @
f5c081d4
...
...
@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
remote_block_size
=
self
.
_block_size
,
# shared state
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
backend
,
attn_backend
s
=
[
backend
]
,
)
self
.
async_zmq_ctx
=
zmq
.
asyncio
.
Context
()
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
f5c081d4
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment