Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5c081d4
Unverified
Commit
f5c081d4
authored
Mar 16, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 16, 2026
Browse files
[PD][Nixl] Add support for hybrid SSM-FA models (#36687)
parent
c88ea833
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
587 additions
and
166 deletions
+587
-166
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
..._connector/nixl_integration/config_sweep_accuracy_test.sh
+8
-0
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+76
-45
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
+112
-0
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+38
-8
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
...v_transfer/kv_connector/v1/mooncake/mooncake_connector.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+351
-112
No files found.
tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
View file @
f5c081d4
...
@@ -18,11 +18,19 @@ dp_ep_configs=(
...
@@ -18,11 +18,19 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
)
hybrid_ssm_configs
=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
# Select config array based on DP_EP env var
# Select config array based on DP_EP env var
if
[[
-n
"
${
DP_EP
:-}
"
]]
;
then
if
[[
-n
"
${
DP_EP
:-}
"
]]
;
then
configs
=(
"
${
dp_ep_configs
[@]
}
"
)
configs
=(
"
${
dp_ep_configs
[@]
}
"
)
echo
"DP_EP is set, using dp_ep_configs"
echo
"DP_EP is set, using dp_ep_configs"
elif
[[
-n
"
${
HYBRID_SSM
:-}
"
]]
;
then
configs
=(
"
${
hybrid_ssm_configs
[@]
}
"
)
echo
"HYBRID_SSM is set, using hybrid_ssm_configs."
else
else
configs
=(
"
${
tp_configs
[@]
}
"
)
configs
=(
"
${
tp_configs
[@]
}
"
)
fi
fi
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
f5c081d4
...
@@ -18,6 +18,7 @@ EXPECTED_VALUES = {
...
@@ -18,6 +18,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/deepseek-vl2-tiny"
:
0.19
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
0.65
,
"google/gemma-3-4b-it"
:
0.74
,
"google/gemma-3-4b-it"
:
0.74
,
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
:
0.84
,
}
}
SIMPLE_PROMPT
=
(
SIMPLE_PROMPT
=
(
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
f5c081d4
...
@@ -53,7 +53,13 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
...
@@ -53,7 +53,13 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from
vllm.v1.attention.backends.utils
import
set_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
set_kv_cache_layout
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
,
KVCacheTensor
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
...
@@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_groups
=
[
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
KVCacheGroupSpec
(
[
"layer0"
,
"layer1"
,
"layer2"
],
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
16
,
dtype
=
torch
.
float16
,
),
)
]
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
2
,
kv_cache_tensors
=
[],
kv_cache_groups
=
kv_cache_groups
)
prefill_connector
=
NixlConnector
(
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
)
...
@@ -437,7 +455,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -437,7 +455,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self
.
kv_cache_layout
=
kv_cache_layout
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
# Mock register_kv_caches attribute needed for tests that do not call it.
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
test_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
test_shape
=
self
.
attn_backend
s
[
0
]
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
)
self
.
kv_topo
=
TpKVTopology
(
self
.
kv_topo
=
TpKVTopology
(
...
@@ -447,7 +465,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -447,7 +465,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
remote_block_size
=
self
.
_block_size
,
# shared state
remote_block_size
=
self
.
_block_size
,
# shared state
is_mla
=
self
.
use_mla
,
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
self
.
attn_backend
,
attn_backend
s
=
self
.
attn_backend
s
,
tensor_shape
=
test_shape
,
tensor_shape
=
test_shape
,
)
)
...
@@ -501,6 +519,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -501,6 +519,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# is started. We mock HND here.
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
ssm_sizes
=
(
0
,
0
),
),
),
remote_tp_rank
=
remote_tp_rank
,
remote_tp_rank
=
remote_tp_rank
,
remote_tp_size
=
remote_tp_size
,
remote_tp_size
=
remote_tp_size
,
...
@@ -951,6 +970,7 @@ class TestNixlHandshake:
...
@@ -951,6 +970,7 @@ class TestNixlHandshake:
block_lens
=
worker
.
block_len_per_layer
,
block_lens
=
worker
.
block_len_per_layer
,
kv_cache_layout
=
mismatched_layout
,
kv_cache_layout
=
mismatched_layout
,
block_size
=
worker
.
block_size
,
block_size
=
worker
.
block_size
,
ssm_sizes
=
(
0
,
0
),
)
)
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
RuntimeError
):
...
@@ -1006,6 +1026,7 @@ class TestNixlHandshake:
...
@@ -1006,6 +1026,7 @@ class TestNixlHandshake:
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
kv_cache_layout
=
"HND"
,
kv_cache_layout
=
"HND"
,
block_size
=
worker
.
block_size
,
block_size
=
worker
.
block_size
,
ssm_sizes
=
(
0
,
0
),
)
)
# We don't check layout for homogeneous TP and MLA for now, as the
# We don't check layout for homogeneous TP and MLA for now, as the
...
@@ -1496,9 +1517,47 @@ def test_register_kv_caches(
...
@@ -1496,9 +1517,47 @@ def test_register_kv_caches(
# test run if not mocking.
# test run if not mocking.
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backends
.
return_value
=
[
backend_cls
]
mock_get_attn_backends
.
return_value
=
[
backend_cls
]
num_layers
=
32
block_size
=
16
num_blocks
=
8
num_heads
=
4
head_size
=
16
# TODO (NickLucche) the fact that connector depends on kv_cache_config for init
# but cross-layer preference cant be inferred prior to creating kv_cache_config
# is a bit awkward.
dummy_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
block_size
),
)
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_heads
,
head_size
=
head_size
,
dtype
=
torch
.
float16
,
)
if
dummy_connector
.
prefer_cross_layer_blocks
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
kv_cache_spec
.
page_size_bytes
*
num_blocks
,
shared_by
=
[
"all-layers"
],
)
for
_
in
range
(
num_layers
)
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"all-layers"
],
kv_cache_spec
)],
)
else
:
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer0"
,
"layer1"
,
"layer2"
],
kv_cache_spec
)
],
)
# Create connector
# Create connector
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
vllm_config
,
...
@@ -1526,35 +1585,6 @@ def test_register_kv_caches(
...
@@ -1526,35 +1585,6 @@ def test_register_kv_caches(
or
connector
.
prefer_cross_layer_blocks
or
connector
.
prefer_cross_layer_blocks
)
)
if
connector
.
prefer_cross_layer_blocks
:
if
connector
.
prefer_cross_layer_blocks
:
num_layers
=
32
block_size
=
16
num_blocks
=
8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
num_blocks
=
num_blocks
)
connector
.
connector_worker
.
kv_cache_config
=
worker_kv_cache_config
connector
.
connector_worker
.
num_blocks
=
worker_kv_cache_config
.
num_blocks
kv_cache_spec
=
AttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
head_size
=
64
,
dtype
=
torch
.
bfloat16
,
)
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
kv_cache_spec
.
page_size_bytes
*
num_blocks
,
shared_by
=
[
"dummy-layer"
],
)
for
i
in
range
(
num_layers
)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups
=
[],
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
_
,
cross_layers_kv_cache
,
_
=
(
_
,
cross_layers_kv_cache
,
_
=
(
KVConnectorModelRunnerMixin
.
allocate_uniform_kv_caches
(
KVConnectorModelRunnerMixin
.
allocate_uniform_kv_caches
(
...
@@ -1586,12 +1616,8 @@ def test_register_kv_caches(
...
@@ -1586,12 +1616,8 @@ def test_register_kv_caches(
expected_blocks_count
=
8
expected_blocks_count
=
8
kv_caches
=
{
"all-layers"
:
cross_layers_kv_cache
}
kv_caches
=
{
"all-layers"
:
cross_layers_kv_cache
}
else
:
else
:
# Create test kv cache tensors using proper backend shape
# Create test kv cache tensors using proper backend shape
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
kv_cache_config
.
num_blocks
,
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
block_size
=
kv_cache_spec
.
block_size
,
...
@@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation(
...
@@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation(
kv_cache_spec
=
cast
(
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
)
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
decode_worker
.
attn_backend
s
[
0
]
.
get_kv_cache_shape
(
num_blocks
=
kv_cache_config
.
num_blocks
,
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
block_size
=
kv_cache_spec
.
block_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
...
@@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation(
...
@@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation(
)
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
# Build kv_caches from the actual layer names in kv_cache_config so that
# _layer_specs lookups in register_kv_caches always find a matching key.
layer_names
=
[
name
for
group
in
kv_cache_config
.
kv_cache_groups
for
name
in
group
.
layer_names
]
kv_caches
=
{
kv_caches
=
{
"layer0"
:
shared_tensor
,
name
:
shared_tensor
if
i
%
2
==
0
else
unique_tensor
"layer1"
:
unique_tensor
,
for
i
,
name
in
enumerate
(
layer_names
)
"layer2"
:
shared_tensor
,
}
}
decode_connector
.
register_kv_caches
(
kv_caches
)
decode_connector
.
register_kv_caches
(
kv_caches
)
...
@@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
...
@@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
block_lens
=
[
4096
*
prefill_block_size
],
# slot_size * block_size
block_lens
=
[
4096
*
prefill_block_size
],
# slot_size * block_size
kv_cache_layout
=
"HND"
,
kv_cache_layout
=
"HND"
,
block_size
=
prefill_block_size
,
block_size
=
prefill_block_size
,
ssm_sizes
=
(
0
,
0
),
)
)
handshake_payload
=
NixlHandshakePayload
(
handshake_payload
=
NixlHandshakePayload
(
compatibility_hash
=
remote_hash
,
compatibility_hash
=
remote_hash
,
...
@@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
...
@@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
remote_block_size
=
decode_worker
.
_block_size
,
# shared state
remote_block_size
=
decode_worker
.
_block_size
,
# shared state
is_mla
=
decode_worker
.
use_mla
,
is_mla
=
decode_worker
.
use_mla
,
total_num_kv_heads
=
decode_worker
.
model_config
.
get_total_num_kv_heads
(),
total_num_kv_heads
=
decode_worker
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
backend
,
attn_backend
s
=
[
backend
]
,
tensor_shape
=
test_shape
,
tensor_shape
=
test_shape
,
)
)
...
...
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
View file @
f5c081d4
...
@@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
...
@@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker
.
_physical_blocks_per_logical_kv_block
=
2
worker
.
_physical_blocks_per_logical_kv_block
=
2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker
.
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
hma_enabled
=
True
)
# Test conversion: FA + SW group
# Test conversion: FA + SW group
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
logical_block_ids
=
[[
0
,
1
,
2
],
[
3
,
4
]]
...
@@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
...
@@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
18
,
19
,
20
,
21
]
@
pytest
.
mark
.
cpu_test
def
test_get_block_descs_ids_hybrid_ssm
():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch)."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
num_blocks
=
100
engine_id
=
"test-engine"
worker
.
num_regions
=
2
worker
.
dst_num_blocks
=
{
engine_id
:
num_blocks
}
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
1
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker
.
num_descs
=
2
*
num_blocks
fa_blocks
=
[
3
,
5
]
ssm_blocks
=
[
1
,
2
]
result
=
worker
.
_get_block_descs_ids
(
engine_id
,
(
fa_blocks
,
ssm_blocks
))
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected
=
[
3
,
5
,
103
,
105
,
201
,
202
,
301
,
302
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
@
pytest
.
mark
.
cpu_test
def
test_get_block_descs_ids_kernel_block_mismatch
():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorWorker
,
)
worker
=
object
.
__new__
(
NixlConnectorWorker
)
ratio
=
4
logical_blocks
=
100
num_blocks
=
logical_blocks
*
ratio
# 400 kernel blocks
engine_id
=
"test-engine"
worker
.
num_regions
=
2
worker
.
dst_num_blocks
=
{
engine_id
:
num_blocks
}
worker
.
_has_mamba
=
True
worker
.
_is_mamba_group
=
[
False
,
True
]
worker
.
_physical_blocks_per_logical_kv_block
=
ratio
worker
.
num_descs
=
2
*
num_blocks
# 800
fa_blocks
=
[
3
,
7
]
# kernel-level block IDs
ssm_blocks
=
[
1
,
2
]
# logical block IDs
result
=
worker
.
_get_block_descs_ids
(
engine_id
,
(
fa_blocks
,
ssm_blocks
))
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected
=
[
3
,
7
,
403
,
407
,
801
,
802
,
901
,
902
]
assert
list
(
result
)
==
expected
,
f
"Expected
{
expected
}
, got
{
list
(
result
)
}
"
@
pytest
.
mark
.
cpu_test
def
test_nixl_metadata_hybrid_ssm_block_ids
():
"""Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
groups with different block counts (kernel mismatch active)."""
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
,
)
metadata
=
NixlConnectorMetadata
()
# FA: 8 kernel blocks (2 logical * ratio=4), SSM: 2 logical blocks
fa_blocks
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]
ssm_blocks
=
[
0
,
1
]
metadata
.
add_new_req_to_recv
(
request_id
=
"test-req-hybrid"
,
local_block_ids
=
(
fa_blocks
,
ssm_blocks
),
kv_transfer_params
=
{
"remote_block_ids"
:
([
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
],
[
20
,
21
]),
"remote_engine_id"
:
"remote-engine"
,
"remote_request_id"
:
"prefill-test-req-hybrid"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"tp_size"
:
1
,
},
)
assert
"test-req-hybrid"
in
metadata
.
reqs_to_recv
req_meta
=
metadata
.
reqs_to_recv
[
"test-req-hybrid"
]
# Verify local block IDs: different lengths per group
assert
len
(
req_meta
.
local_block_ids
)
==
2
assert
list
(
req_meta
.
local_block_ids
[
0
])
==
fa_blocks
assert
list
(
req_meta
.
local_block_ids
[
1
])
==
ssm_blocks
assert
len
(
req_meta
.
local_block_ids
[
0
])
!=
len
(
req_meta
.
local_block_ids
[
1
])
# Verify remote block IDs: same asymmetry preserved
assert
req_meta
.
remote
is
not
None
assert
len
(
req_meta
.
remote
.
block_ids
)
==
2
assert
list
(
req_meta
.
remote
.
block_ids
[
0
])
==
[
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
]
assert
list
(
req_meta
.
remote
.
block_ids
[
1
])
==
[
20
,
21
]
assert
len
(
req_meta
.
remote
.
block_ids
[
0
])
!=
len
(
req_meta
.
remote
.
block_ids
[
1
])
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
f5c081d4
...
@@ -16,10 +16,12 @@ from vllm.logger import init_logger
...
@@ -16,10 +16,12 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_cache_interface
import
MambaSpec
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -328,22 +330,26 @@ class TpKVTopology:
...
@@ -328,22 +330,26 @@ class TpKVTopology:
remote_tp_size
:
dict
[
EngineId
,
int
]
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
is_mla
:
bool
total_num_kv_heads
:
int
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
attn_backend
s
:
list
[
type
[
AttentionBackend
]
]
engine_id
:
EngineId
engine_id
:
EngineId
remote_block_size
:
dict
[
EngineId
,
int
]
remote_block_size
:
dict
[
EngineId
,
int
]
tensor_shape
:
torch
.
Size
|
None
=
None
tensor_shape
:
torch
.
Size
|
None
=
None
is_mamba
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
# or num_blocks. This is used to register the memory regions correctly.
attn_backend
=
self
.
attn_backends
[
0
]
if
not
self
.
is_mamba
:
_MOCK_BLOCK_SIZE
=
16
_MOCK_BLOCK_SIZE
=
16
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
:
tuple
[
int
,
...]
=
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
_MOCK_BLOCK_SIZE
,
num_kv_heads
=
1
,
head_size
=
1
num_blocks
=
1
,
block_size
=
_MOCK_BLOCK_SIZE
,
num_kv_heads
=
1
,
head_size
=
1
)
)
logger
.
debug
(
"Test kv_cache_shape: %s"
,
kv_cache_shape
)
logger
.
debug
(
"Test kv_cache_shape: %s"
,
kv_cache_shape
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
# we just mock num_blocks to 1 for the dimension check below.
self
.
_is_kv_layout_blocks_first
=
(
# Hybrid SSM models assume a single blocks_first layout
self
.
_is_kv_layout_blocks_first
=
self
.
is_mamba
or
(
len
(
kv_cache_shape
)
==
5
and
kv_cache_shape
[
0
]
==
1
len
(
kv_cache_shape
)
==
5
and
kv_cache_shape
[
0
]
==
1
)
)
...
@@ -360,7 +366,7 @@ class TpKVTopology:
...
@@ -360,7 +366,7 @@ class TpKVTopology:
_MOCK_NUM_LAYERS
=
80
_MOCK_NUM_LAYERS
=
80
kv_cache_shape
=
(
_MOCK_NUM_LAYERS
,)
+
kv_cache_shape
kv_cache_shape
=
(
_MOCK_NUM_LAYERS
,)
+
kv_cache_shape
try
:
try
:
kv_cache_stride_order
=
self
.
attn_backend
.
get_kv_cache_stride_order
(
kv_cache_stride_order
=
attn_backend
.
get_kv_cache_stride_order
(
include_num_layers_dimension
=
self
.
_cross_layers_blocks
include_num_layers_dimension
=
self
.
_cross_layers_blocks
)
)
except
(
AttributeError
,
NotImplementedError
):
except
(
AttributeError
,
NotImplementedError
):
...
@@ -483,6 +489,30 @@ class TpKVTopology:
...
@@ -483,6 +489,30 @@ class TpKVTopology:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_ranks
(
remote_tp_size
)
return
self
.
get_target_remote_ranks
(
remote_tp_size
)
def
get_transfer_cache_regions
(
self
,
cache
:
torch
.
Tensor
,
layer_spec
:
"KVCacheSpec"
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
"""Return the cache tensor(s) to register as NIXL memory regions,
also accounting for hybrid SSM models specificities.
"""
if
isinstance
(
layer_spec
,
MambaSpec
):
# Register the whole kv cache shared tensor, including SSM/Conv. This is
# similar to FI with the difference that SSM/Conv have different sizes
conv
,
ssm
=
cache
return
[
conv
]
# Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`.
if
self
.
is_mamba
and
cache
.
shape
[
0
]
==
2
:
# When MAMBA is present, all backends are blocks first, so that blocks
# can be shared between attention layers and mamba layers. Runner
# `_update_hybrid_attention_mamba_layout` already adjusted strides
# for FlashAttn-like backends so its num_blocks first.
# Swap [2<>num_blocks] dims to get required layout for hybrid SSM.
cache
=
cache
.
transpose
(
0
,
1
)
# Regular case: backends like FA register K/V in separate regions
return
cache
if
self
.
split_k_and_v
else
[
cache
]
def
get_current_attn_backends
(
def
get_current_attn_backends
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
...
...
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
View file @
f5c081d4
...
@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
...
@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
remote_block_size
=
self
.
_block_size
,
# shared state
remote_block_size
=
self
.
_block_size
,
# shared state
is_mla
=
self
.
use_mla
,
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
backend
,
attn_backend
s
=
[
backend
]
,
)
)
self
.
async_zmq_ctx
=
zmq
.
asyncio
.
Context
()
self
.
async_zmq_ctx
=
zmq
.
asyncio
.
Context
()
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
f5c081d4
...
@@ -59,7 +59,12 @@ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
...
@@ -59,7 +59,12 @@ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.utils
import
select_common_block_size
from
vllm.v1.worker.utils
import
select_common_block_size
...
@@ -159,6 +164,7 @@ class NixlAgentMetadata:
...
@@ -159,6 +164,7 @@ class NixlAgentMetadata:
block_lens
:
list
[
int
]
block_lens
:
list
[
int
]
kv_cache_layout
:
str
kv_cache_layout
:
str
block_size
:
int
block_size
:
int
ssm_sizes
:
tuple
[
int
,
int
]
@
dataclass
@
dataclass
...
@@ -310,6 +316,15 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -310,6 +316,15 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class
NixlConnector
(
KVConnectorBase_V1
,
SupportsHMA
):
class
NixlConnector
(
KVConnectorBase_V1
,
SupportsHMA
):
@
property
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
if
any
(
[
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
)
for
group
in
self
.
kv_cache_config
.
kv_cache_groups
]
):
# Hybrid SSM models do not yet support cross-layer layout
return
False
backend
=
get_current_attn_backend
(
self
.
_vllm_config
)
backend
=
get_current_attn_backend
(
self
.
_vllm_config
)
if
backend
.
get_name
()
not
in
(
if
backend
.
get_name
()
not
in
(
"FLASH_ATTN"
,
"FLASH_ATTN"
,
...
@@ -335,12 +350,9 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
...
@@ -335,12 +350,9 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
kv_cache_config
:
"KVCacheConfig"
,
kv_cache_config
:
"KVCacheConfig"
,
):
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
for
group
in
kv_cache_config
.
kv_cache_groups
:
self
.
kv_cache_config
=
kv_cache_config
if
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
):
raise
ValueError
(
"NixlConnector does not support Mamba models."
)
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
if
role
==
KVConnectorRole
.
SCHEDULER
:
if
role
==
KVConnectorRole
.
SCHEDULER
:
...
@@ -434,11 +446,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
...
@@ -434,11 +446,7 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA):
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
):
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
self
.
connector_worker
.
register_cross_layers_kv_caches
(
kv_cache
)
cross_layer_name
=
"ALL_LAYERS"
kv_caches
=
{
cross_layer_name
:
kv_cache
}
self
.
connector_worker
.
register_kv_caches
(
kv_caches
)
def
set_host_xfer_buffer_ops
(
self
,
copy_operation
:
CopyBlocksOp
):
def
set_host_xfer_buffer_ops
(
self
,
copy_operation
:
CopyBlocksOp
):
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
...
@@ -962,6 +970,40 @@ class NixlConnectorWorker:
...
@@ -962,6 +970,40 @@ class NixlConnectorWorker:
)
)
)
)
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
self
.
_layer_specs
=
{
layer
:
group
.
kv_cache_spec
for
group
in
kv_cache_config
.
kv_cache_groups
for
layer
in
group
.
layer_names
}
self
.
hma_group_size
=
len
(
kv_cache_config
.
kv_cache_tensors
)
# Mamba metadata
self
.
_is_mamba_group
=
[
isinstance
(
group
.
kv_cache_spec
,
MambaSpec
)
for
group
in
kv_cache_config
.
kv_cache_groups
]
mamba_ssm_size
=
(
0
,
0
)
self
.
_has_mamba
=
any
(
self
.
_is_mamba_group
)
if
self
.
_has_mamba
:
assert
self
.
_is_hma_required
mamba_spec
=
next
(
spec
for
spec
in
self
.
_layer_specs
.
values
()
if
isinstance
(
spec
,
MambaSpec
)
)
conv_nbytes
,
ssm_nbytes
=
(
torch
.
tensor
([],
dtype
=
mamba_spec
.
dtypes
[
0
]).
element_size
(),
# type: ignore[misc]
torch
.
tensor
([],
dtype
=
mamba_spec
.
dtypes
[
1
]).
element_size
(),
# type: ignore[misc]
)
conv_shape
,
ssm_shape
=
(
torch
.
Size
(
mamba_spec
.
shapes
[
0
]),
torch
.
Size
(
mamba_spec
.
shapes
[
1
]),
)
mamba_ssm_size
=
(
conv_shape
.
numel
()
*
conv_nbytes
,
ssm_shape
.
numel
()
*
ssm_nbytes
,
)
self
.
_mamba_ssm_size
=
mamba_ssm_size
# Agent.
# Agent.
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
...
@@ -1106,9 +1148,9 @@ class NixlConnectorWorker:
...
@@ -1106,9 +1148,9 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
# NOTE (NickLucche) models with multiple backends are not supported yet
self
.
attn_backend
=
get_current_attn_backend
(
vllm_config
)
self
.
attn_backends
=
get_current_attn_backends
(
vllm_config
)
self
.
backend_name
=
self
.
attn_backends
[
0
].
get_name
()
self
.
backend_name
=
self
.
attn_backend
.
get_name
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
...
@@ -1135,6 +1177,8 @@ class NixlConnectorWorker:
...
@@ -1135,6 +1177,8 @@ class NixlConnectorWorker:
def
_sync_block_size_with_kernel
(
self
)
->
None
:
def
_sync_block_size_with_kernel
(
self
)
->
None
:
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
kernel_block_size
=
select_common_block_size
(
self
.
block_size
,
backends
)
kernel_block_size
=
select_common_block_size
(
self
.
block_size
,
backends
)
# Number of blocks not accounting for kernel block mismatches
self
.
_logical_num_blocks
=
self
.
num_blocks
if
self
.
block_size
!=
kernel_block_size
:
if
self
.
block_size
!=
kernel_block_size
:
logger
.
info_once
(
logger
.
info_once
(
"User-specified logical block size (%s) does not match"
"User-specified logical block size (%s) does not match"
...
@@ -1428,9 +1472,19 @@ class NixlConnectorWorker:
...
@@ -1428,9 +1472,19 @@ class NixlConnectorWorker:
fut
.
add_done_callback
(
request_ready
)
fut
.
add_done_callback
(
request_ready
)
def
register_cross_layers_kv_caches
(
self
,
kv_cache
:
torch
.
Tensor
)
->
None
:
"""Register a cross-layers KV cache tensor with NIXL.
`use_uniform_kv_cache()` guarantees a single KV cache group whose
layers all share the same `AttentionSpec`, so any layer name from
`_layer_specs` yields the correct per-layer spec for `page_size_bytes`.
"""
first_layer
=
next
(
iter
(
self
.
_layer_specs
))
# Forwarding a real layer name rather than a synthetic key
self
.
register_kv_caches
({
first_layer
:
kv_cache
})
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
"""Register the KV Cache data in nixl."""
self
.
kv_topo
=
TpKVTopology
(
self
.
kv_topo
=
TpKVTopology
(
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
...
@@ -1438,8 +1492,12 @@ class NixlConnectorWorker:
...
@@ -1438,8 +1492,12 @@ class NixlConnectorWorker:
remote_block_size
=
self
.
_block_size
,
# shared state
remote_block_size
=
self
.
_block_size
,
# shared state
is_mla
=
self
.
use_mla
,
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
self
.
attn_backend
,
attn_backends
=
self
.
attn_backends
,
tensor_shape
=
next
(
iter
(
kv_caches
.
values
())).
shape
,
# SSM States come in tuples (ssm, conv)
tensor_shape
=
next
(
iter
(
kv_caches
.
values
())).
shape
if
not
self
.
_has_mamba
else
None
,
is_mamba
=
self
.
_has_mamba
,
)
)
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
self
.
vllm_config
,
self
.
backend_name
,
self
.
kv_topo
.
cross_layers_blocks
self
.
vllm_config
,
self
.
backend_name
,
self
.
kv_topo
.
cross_layers_blocks
...
@@ -1481,12 +1539,50 @@ class NixlConnectorWorker:
...
@@ -1481,12 +1539,50 @@ class NixlConnectorWorker:
# to better exploit the memory layout (ie num_blocks is the first dim).
# to better exploit the memory layout (ie num_blocks is the first dim).
tensor_size_bytes
=
None
tensor_size_bytes
=
None
# Enable different block lengths for different layers when MLA is used.
# Enable different block lengths for different layers *only* when MLA is used.
# This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.
self
.
block_len_per_layer
=
list
[
int
]()
self
.
block_len_per_layer
=
list
[
int
]()
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
cache_list
=
(
# NOTE (NickLucche) Hybrid SSM models assume a layout that is similar to
cache_or_caches
if
self
.
kv_topo
.
split_k_and_v
else
[
cache_or_caches
]
# that of FI, with block laid out as in `get_backend_aware_kv_block_len`.
)
# However, physical page_size may differ when kernel requires a specific
# block size. This leads to SSM and FA layers having different num_blocks.
# `_physical_blocks_per_logical_kv_block` ratio is used to adjust for this.
layer_spec
=
self
.
_layer_specs
[
layer_name
]
if
isinstance
(
layer_spec
,
UniformTypeKVCacheSpecs
):
# MLA DSv32 Indexer case: UniformTypeKVCacheSpecs merges kv_cache_specs
layer_spec
=
layer_spec
.
kv_cache_specs
[
layer_name
]
cache_list
=
self
.
kv_topo
.
get_transfer_cache_regions
(
cache_or_caches
,
layer_spec
)
# `layer_spec.page_size_bytes` only accounts for logical page_size, that is
# the page_size assuming constant `self._logical_num_blocks`.
physical_page_size
=
(
layer_spec
.
page_size_bytes
if
isinstance
(
layer_spec
,
MambaSpec
)
else
layer_spec
.
page_size_bytes
//
self
.
_physical_blocks_per_logical_kv_block
)
# For when registering multiple tensors eg K/V in separate regions.
physical_page_size
=
physical_page_size
//
len
(
cache_list
)
if
self
.
kv_topo
.
_cross_layers_blocks
:
# When cross-layers blocks are used, multiply by number of layers
physical_page_size
=
physical_page_size
*
len
(
self
.
kv_cache_config
.
kv_cache_tensors
)
num_blocks
=
(
self
.
_logical_num_blocks
if
isinstance
(
layer_spec
,
MambaSpec
)
else
self
.
num_blocks
)
# `page_size` accounts for physical blocks, st KVCache is always
# [`num_blocks` * `page_size`]
curr_tensor_size_bytes
=
num_blocks
*
physical_page_size
if
tensor_size_bytes
is
None
:
tensor_size_bytes
=
curr_tensor_size_bytes
# TODO (NickLucche) we could eventually unify how we handle FA/FI regions,
# registering a single tensor for both K/V and splitting logically like FI.
for
cache
in
cache_list
:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
base_addr
=
cache
.
data_ptr
()
if
base_addr
in
seen_base_addresses
:
if
base_addr
in
seen_base_addresses
:
...
@@ -1494,27 +1590,27 @@ class NixlConnectorWorker:
...
@@ -1494,27 +1590,27 @@ class NixlConnectorWorker:
# across groups. This results in skipping all tensors but the ones
# across groups. This results in skipping all tensors but the ones
# pointed to by group0. Also, generally we will have more blocks
# pointed to by group0. Also, generally we will have more blocks
# per tensor but fewer regions.
# per tensor but fewer regions.
logger
.
debug
(
"Skipping %s because it's already seen"
,
layer_name
)
continue
continue
logger
.
debug
(
logger
.
debug
(
"Registering layer %s with cache shape: %s"
,
layer_name
,
cache
.
shape
"Registering layer %s with cache shape: %s"
,
layer_name
,
cache
.
shape
)
)
seen_base_addresses
.
append
(
base_addr
)
seen_base_addresses
.
append
(
base_addr
)
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
# Only record non-Mamba page sizes.
if
isinstance
(
layer_spec
,
MambaSpec
):
if
tensor_size_bytes
is
None
:
self
.
block_len_per_layer
.
append
(
tensor_size_bytes
=
curr_tensor_size_bytes
physical_page_size
//
self
.
_physical_blocks_per_logical_kv_block
assert
cache
.
shape
[
0
]
==
self
.
num_blocks
,
(
"All kv cache tensors must have the same number of blocks"
)
)
else
:
self
.
block_len_per_layer
.
append
(
physical_page_size
)
self
.
block_len_per_layer
.
append
(
assert
cache
.
shape
[
0
]
==
num_blocks
,
(
curr_tensor_size_bytes
//
self
.
num_
blocks
"All kv cache tensors must have the same number of
blocks
"
)
)
if
not
self
.
use_mla
:
if
not
self
.
use_mla
:
# Different kv cache shape is not supported by HeteroTP
# Different kv cache shape is not supported by HeteroTP.
# This must also hold true for Mamba-like models.
assert
tensor_size_bytes
==
curr_tensor_size_bytes
,
(
assert
tensor_size_bytes
==
curr_tensor_size_bytes
,
(
"All kv cache tensors must have the same size"
"All kv cache tensors must have the same size"
)
)
...
@@ -1533,6 +1629,21 @@ class NixlConnectorWorker:
...
@@ -1533,6 +1629,21 @@ class NixlConnectorWorker:
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
]
=
seen_base_addresses
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_regions
=
len
(
caches_data
)
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence we double the number
# of 'virtual' regions here and halve `block_len` below.
# Similarly for Mamba layers, we register SSM+Conv as a single region and
# then duplicate it logically to be able to index SSM/Conv separately.
self
.
num_regions
*=
2
# TODO (NickLucche) Adapt to different descs views (engine_id->tp_rank) to
# support heterogeneous TP.
self
.
num_descs
=
self
.
num_regions
*
self
.
num_blocks
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
self
.
nixl_wrapper
.
register_memory
(
descs
,
backends
=
self
.
nixl_backends
)
self
.
nixl_wrapper
.
register_memory
(
descs
,
backends
=
self
.
nixl_backends
)
...
@@ -1542,17 +1653,21 @@ class NixlConnectorWorker:
...
@@ -1542,17 +1653,21 @@ class NixlConnectorWorker:
self
.
device_kv_caches
=
kv_caches
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
self
.
_has_mamba
:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
logger
.
info
(
# with joint KV for each block. This minimizes the overhead in
"Hybrid SSM registration: num_blocks=%s, "
# registerMem allowing faster descs queries. In order to be able to
"logical_num_blocks=%s, ratio=%s, num_regions=%s, "
# split on kv_heads dim as required by heterogeneous TP, one must
"num_descs=%s, mamba_ssm_size=%s, block_len_per_layer=%s"
,
# be able to index K/V separately. Hence we double the number
self
.
num_blocks
,
# of 'virtual' regions here and halve `block_len` below.
self
.
_logical_num_blocks
,
self
.
num_regions
*=
2
self
.
_physical_blocks_per_logical_kv_block
,
self
.
num_regions
,
self
.
num_descs
,
self
.
_mamba_ssm_size
,
set
(
self
.
block_len_per_layer
),
)
# Register local/src descr for NIXL xfer.
# Register local/src descr for NIXL xfer.
self
.
seen_base_addresses
=
seen_base_addresses
self
.
src_xfer_handles_by_block_size
[
self
.
block_size
],
self
.
src_blocks_data
=
(
self
.
src_xfer_handles_by_block_size
[
self
.
block_size
],
self
.
src_blocks_data
=
(
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
register_local_xfer_handler
(
self
.
block_size
)
)
)
...
@@ -1569,6 +1684,7 @@ class NixlConnectorWorker:
...
@@ -1569,6 +1684,7 @@ class NixlConnectorWorker:
if
not
self
.
use_host_buffer
if
not
self
.
use_host_buffer
else
self
.
host_buffer_kv_cache_layout
,
else
self
.
host_buffer_kv_cache_layout
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
ssm_sizes
=
self
.
_mamba_ssm_size
,
)
)
# Wrap metadata in payload with hash for defensive decoding
# Wrap metadata in payload with hash for defensive decoding
assert
self
.
compat_hash
is
not
None
assert
self
.
compat_hash
is
not
None
...
@@ -1594,24 +1710,41 @@ class NixlConnectorWorker:
...
@@ -1594,24 +1710,41 @@ class NixlConnectorWorker:
data copy correctness.
data copy correctness.
"""
"""
assert
self
.
kv_topo
is
not
None
assert
self
.
kv_topo
is
not
None
kv_topo
=
self
.
kv_topo
block_size_ratio
=
self
.
block_size
//
block_size
block_size_ratio
=
self
.
block_size
//
block_size
blocks_data
=
[]
blocks_data
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
base_addr
in
enumerate
(
self
.
seen_base_addresses
):
local_base_addresses
=
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
]
def
register_blocks
(
blocks_data
:
list
[
tuple
[
int
,
int
,
int
]],
mamba
:
bool
):
for
i
,
base_addr
in
enumerate
(
local_base_addresses
):
# The new block_len is using prefill block_len;
# The new block_len is using prefill block_len;
# and num_blocks is multiple with N
# and num_blocks is multiple with N
kv_block_len
=
(
kv_block_len
=
(
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
//
block_size_ratio
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
,
first_split
=
True
,
mamba_view
=
mamba
)
)
block_len_per_layer
=
self
.
block_len_per_layer
[
i
]
//
block_size_ratio
//
block_size_ratio
num_blocks
=
self
.
num_blocks
*
block_size_ratio
)
# Jump one page_size, but ssm page_size may be bigger when kernel
# locks block size to a specific value.
block_len_per_layer
=
(
self
.
block_len_per_layer
[
i
]
//
block_size_ratio
*
(
1
if
not
mamba
else
self
.
_physical_blocks_per_logical_kv_block
)
)
num_blocks
=
self
.
_logical_num_blocks
if
mamba
else
self
.
num_blocks
num_blocks
=
num_blocks
*
block_size_ratio
for
block_id
in
range
(
num_blocks
):
for
block_id
in
range
(
num_blocks
):
block_offset
=
block_id
*
block_len_per_layer
block_offset
=
block_id
*
block_len_per_layer
addr
=
base_addr
+
block_offset
addr
=
base_addr
+
block_offset
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
device_id
))
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
device_id
))
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
kv_topo
.
is_kv_layout_blocks_first
:
second_split
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
,
first_split
=
False
,
mamba_view
=
mamba
)
# Separate and interleave K/V regions to maintain the same
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
# when split across TP ranks.
...
@@ -1620,7 +1753,7 @@ class NixlConnectorWorker:
...
@@ -1620,7 +1753,7 @@ class NixlConnectorWorker:
addr
=
base_addr
+
block_offset
addr
=
base_addr
+
block_offset
# Register addresses for V cache (K registered first).
# Register addresses for V cache (K registered first).
v_addr
=
addr
+
kv_block_len
v_addr
=
addr
+
kv_block_len
blocks_data
.
append
((
v_addr
,
kv_block_len
,
self
.
device_id
))
blocks_data
.
append
((
v_addr
,
second_split
,
self
.
device_id
))
logger
.
debug
(
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s on device id %s"
,
"Created %s blocks for src engine %s and rank %s on device id %s"
,
len
(
blocks_data
),
len
(
blocks_data
),
...
@@ -1629,6 +1762,14 @@ class NixlConnectorWorker:
...
@@ -1629,6 +1762,14 @@ class NixlConnectorWorker:
self
.
device_id
,
self
.
device_id
,
)
)
register_blocks
(
blocks_data
,
mamba
=
False
)
if
self
.
_has_mamba
:
assert
self
.
num_descs
==
len
(
blocks_data
)
logger
.
debug
(
"Registering additional %s local Mamba blocks"
,
len
(
blocks_data
)
)
register_blocks
(
blocks_data
,
mamba
=
True
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
# NIXL_INIT_AGENT to be used for preparations of local descs.
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
),
blocks_data
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
),
blocks_data
...
@@ -1708,7 +1849,8 @@ class NixlConnectorWorker:
...
@@ -1708,7 +1849,8 @@ class NixlConnectorWorker:
# local origin:| 0| 1| 8| 12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert
self
.
kv_topo
is
not
None
assert
self
.
kv_topo
is
not
None
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
engine_id
)
kv_topo
=
self
.
kv_topo
block_size_ratio
=
kv_topo
.
block_size_ratio_from_engine_id
(
engine_id
)
if
engine_id
not
in
self
.
dst_num_blocks
:
if
engine_id
not
in
self
.
dst_num_blocks
:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
...
@@ -1768,9 +1910,14 @@ class NixlConnectorWorker:
...
@@ -1768,9 +1910,14 @@ class NixlConnectorWorker:
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Register all remote blocks, but only the corresponding kv heads.
# Register all remote blocks, but only the corresponding kv heads.
def
register_remote_blocks
(
blocks_data
:
list
[
tuple
[
int
,
int
,
int
]],
mamba
:
bool
):
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
# Read our whole local region size from remote.
# Read our whole local region size from remote.
local_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
local_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
,
first_split
=
True
,
mamba_view
=
mamba
)
remote_kv_block_len
=
local_block_len
//
block_size_ratio
remote_kv_block_len
=
local_block_len
//
block_size_ratio
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
# using remote kv_block_len as transfer unit
# using remote kv_block_len as transfer unit
...
@@ -1784,33 +1931,66 @@ class NixlConnectorWorker:
...
@@ -1784,33 +1931,66 @@ class NixlConnectorWorker:
if
indexes_into_remote
if
indexes_into_remote
else
0
else
0
)
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
# Assume same num_blocks for mamba and fa
num_blocks
=
(
nixl_agent_meta
.
num_blocks
if
not
mamba
else
nixl_agent_meta
.
num_blocks
//
self
.
_physical_blocks_per_logical_kv_block
)
page_size
=
nixl_agent_meta
.
block_lens
[
i
]
*
(
1
if
not
mamba
else
self
.
_physical_blocks_per_logical_kv_block
)
for
block_id
in
range
(
num_blocks
):
block_offset
=
block_id
*
page_size
# For each block, grab the heads chunk belonging to rank_i
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
local_block_len
,
nixl_agent_meta
.
device_id
))
blocks_data
.
append
(
(
addr
,
local_block_len
,
nixl_agent_meta
.
device_id
)
)
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
kv_topo
.
is_kv_layout_blocks_first
:
# With FlashInfer index V separately to allow head splitting.
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
second_split
=
self
.
get_backend_aware_kv_block_len
(
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
layer_idx
=
i
,
first_split
=
False
,
mamba_view
=
mamba
)
# Apply the same scaling as local_block_len above for when we read
# a chunk of local V from `tp_ratio` separate remote workers.
if
tp_ratio
<
0
and
not
self
.
use_mla
:
second_split
=
second_split
//
(
-
tp_ratio
)
for
block_id
in
range
(
num_blocks
):
block_offset
=
block_id
*
page_size
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
# Hop over the first split of remote page: either K or Conv.
if
mamba
:
v_addr
=
addr
+
nixl_agent_meta
.
ssm_sizes
[
0
]
else
:
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
(
blocks_data
.
append
(
(
v_addr
,
local_block_len
,
nixl_agent_meta
.
device_id
)
(
v_addr
,
second_split
,
nixl_agent_meta
.
device_id
)
)
)
logger
.
debug
(
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s"
,
"Created %s blocks for dst engine %s"
" with remote rank %s and local rank %s"
,
len
(
blocks_data
),
len
(
blocks_data
),
engine_id
,
engine_id
,
remote_tp_rank
,
remote_tp_rank
,
self
.
tp_rank
,
self
.
tp_rank
,
)
)
register_remote_blocks
(
blocks_data
,
mamba
=
False
)
if
self
.
_has_mamba
:
# Create extra descs for the Mamba "view" of the same KV cache tensors.
logger
.
debug
(
"Registering additional %s remote Mamba blocks"
,
len
(
blocks_data
)
)
register_remote_blocks
(
blocks_data
,
mamba
=
True
)
# Register with NIXL.
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
][
remote_tp_rank
]
=
(
self
.
dst_xfer_side_handles
[
engine_id
][
remote_tp_rank
]
=
(
...
@@ -1849,6 +2029,9 @@ class NixlConnectorWorker:
...
@@ -1849,6 +2029,9 @@ class NixlConnectorWorker:
assert
block_size_ratio
==
1
,
(
assert
block_size_ratio
==
1
,
(
"HMA does not support different remote block size yet"
"HMA does not support different remote block size yet"
)
)
# Mamba additional constraints
if
self
.
_has_mamba
:
assert
tp_ratio
==
1
,
"Mamba does not support heterogeneous TP yet"
kv_cache_layout
=
(
kv_cache_layout
=
(
self
.
kv_cache_layout
self
.
kv_cache_layout
...
@@ -2495,6 +2678,7 @@ class NixlConnectorWorker:
...
@@ -2495,6 +2678,7 @@ class NixlConnectorWorker:
A single flattened array is returned for all groups anyway.
A single flattened array is returned for all groups anyway.
"""
"""
region_ids
=
np
.
arange
(
self
.
num_regions
)
region_ids
=
np
.
arange
(
self
.
num_regions
)
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
# layers from different groups share the same kv tensor.
# layers from different groups share the same kv tensor.
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
...
@@ -2505,11 +2689,33 @@ class NixlConnectorWorker:
...
@@ -2505,11 +2689,33 @@ class NixlConnectorWorker:
if
block_size_ratio
is
not
None
:
if
block_size_ratio
is
not
None
:
num_blocks
=
int
(
num_blocks
*
block_size_ratio
)
num_blocks
=
int
(
num_blocks
*
block_size_ratio
)
# Compute the desc ids for each block.
# Compute desc ids per group using the right stride: FA descs have
# num_blocks entries per region (kernel granularity), SSM descs have
# logical_blocks entries per region (no kernel splitting).
region_ids
=
region_ids
[:,
None
]
region_ids
=
region_ids
[:,
None
]
if
not
self
.
_has_mamba
:
block_ids
=
np
.
concatenate
(
block_ids
)[
None
,
:]
block_ids
=
np
.
concatenate
(
block_ids
)[
None
,
:]
descs_ids
=
region_ids
*
num_blocks
+
block_ids
descs_ids
=
region_ids
*
num_blocks
+
block_ids
return
descs_ids
.
flatten
()
return
descs_ids
.
flatten
()
else
:
# NOTE (NickLucche) SSM and Attention blocks regions can be exchanged
# arbitrarily by manager. Therefore, descs are duplicated for SSM and
# Attention like so:
# desc_handle->[descs_fa (all regions) | descs_ssm (all regions)].
# This is like having two "low-level views" of the same storage.
# `num_fa_descs` offset must be computed per-engine since P and D can
# have different num_blocks (and thus different FA descs counts).
ratio
=
self
.
_physical_blocks_per_logical_kv_block
# SSM may register fewer num_blocks than FA
logical_blocks
=
num_blocks
//
ratio
num_fa_descs
=
self
.
num_regions
*
num_blocks
all_descs
=
[]
for
i
,
group
in
enumerate
(
block_ids
):
stride
=
logical_blocks
if
self
.
_is_mamba_group
[
i
]
else
num_blocks
group_arr
=
np
.
asarray
(
group
)[
None
,
:]
offset
=
num_fa_descs
if
self
.
_is_mamba_group
[
i
]
else
0
all_descs
.
append
((
region_ids
*
stride
+
group_arr
+
offset
).
flatten
())
return
np
.
concatenate
(
all_descs
)
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
def
_logical_to_kernel_block_ids
(
self
,
block_ids
:
BlockIds
)
->
BlockIds
:
"""
"""
...
@@ -2523,16 +2729,22 @@ class NixlConnectorWorker:
...
@@ -2523,16 +2729,22 @@ class NixlConnectorWorker:
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
block_arange
=
np
.
arange
(
0
,
self
.
_physical_blocks_per_logical_kv_block
).
reshape
(
1
,
-
1
1
,
-
1
)
)
# Mamba blocks have no logical<>physical discrepancy
group_specs
=
self
.
kv_cache_config
.
kv_cache_groups
return
[
return
[
BlockTable
.
map_to_kernel_blocks
(
BlockTable
.
map_to_kernel_blocks
(
np
.
array
(
group
),
np
.
array
(
group
),
self
.
_physical_blocks_per_logical_kv_block
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
,
block_arange
,
).
tolist
()
).
tolist
()
for
group
in
block_ids
if
not
isinstance
(
group_specs
[
i
].
kv_cache_spec
,
MambaSpec
)
else
group
for
i
,
group
in
enumerate
(
block_ids
)
]
]
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
,
first_split
:
bool
=
True
,
mamba_view
:
bool
=
False
)
->
int
:
"""
"""
Get the block length for one K/V element (K and V have the same size).
Get the block length for one K/V element (K and V have the same size).
...
@@ -2540,10 +2752,37 @@ class NixlConnectorWorker:
...
@@ -2540,10 +2752,37 @@ class NixlConnectorWorker:
block, as K and V are in separate regions.
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
share the same region.
Similarly, for SSM-based models, state and conv are interleaved, but crucially
the their size differs.
Reference diagram:
KVCacheTensor (Shared)
/
\
/
\
/
\
Attention (FlashInfer) View Mamba View
| |
| |
+-------------------+ +-------------------+
| KVCacheTensor | | KVCacheTensor |
| | | |
|<----- page ------>| |<----- page ------->|
| size | | size |
| Key 0 | Val 0 | |Conv 0 | SSM 0 |
| Key 1 | Val 1 | |Conv 1 | SSM 1 |
| ... | ... | | ... | ... |
| Key N-2 | Val N-2 | |Conv N-2| SSM N-2 |
| Key N-1 | Val N-1 | |Conv N-1| SSM N-1 |
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""
"""
assert
self
.
kv_topo
is
not
None
assert
self
.
kv_topo
is
not
None
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# For indexing only half (either just the K or V part).
# For indexing only half (either just the K or V part).
if
mamba_view
:
# NOTE (NickLucche) Mamba Opt: this is already skipping the padding so
# we're only transferring the minimum required bytes.
block_len
=
self
.
_mamba_ssm_size
[
not
first_split
]
else
:
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
//
2
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
//
2
else
:
else
:
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment