Unverified Commit f5c081d4 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[PD][Nixl] Add support for hybrid SSM-FA models (#36687)

parent c88ea833
......@@ -18,11 +18,19 @@ dp_ep_configs=(
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
)
# Select config array based on DP_EP env var
if [[ -n "${DP_EP:-}" ]]; then
configs=("${dp_ep_configs[@]}")
echo "DP_EP is set, using dp_ep_configs"
elif [[ -n "${HYBRID_SSM:-}" ]]; then
configs=("${hybrid_ssm_configs[@]}")
echo "HYBRID_SSM is set, using hybrid_ssm_configs."
else
configs=("${tp_configs[@]}")
fi
......
......@@ -18,6 +18,7 @@ EXPECTED_VALUES = {
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": 0.84,
}
SIMPLE_PROMPT = (
......
......@@ -53,7 +53,13 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import set_kv_cache_layout
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.kv_cache_interface import (
AttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheTensor,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
......@@ -332,8 +338,20 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
kv_cache_groups = [
KVCacheGroupSpec(
["layer0", "layer1", "layer2"],
FullAttentionSpec(
block_size=16,
num_kv_heads=4,
head_size=16,
dtype=torch.float16,
),
)
]
kv_cache_config = KVCacheConfig(
num_blocks=2, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
)
prefill_connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
......@@ -437,7 +455,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
test_shape = self.attn_backends[0].get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
......@@ -447,7 +465,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
attn_backends=self.attn_backends,
tensor_shape=test_shape,
)
......@@ -501,6 +519,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
ssm_sizes=(0, 0),
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
......@@ -951,6 +970,7 @@ class TestNixlHandshake:
block_lens=worker.block_len_per_layer,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
)
with pytest.raises(RuntimeError):
......@@ -1006,6 +1026,7 @@ class TestNixlHandshake:
block_lens=[i * 2 for i in worker.block_len_per_layer],
kv_cache_layout="HND",
block_size=worker.block_size,
ssm_sizes=(0, 0),
)
# We don't check layout for homogeneous TP and MLA for now, as the
......@@ -1496,9 +1517,47 @@ def test_register_kv_caches(
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
mock_get_attn_backends.return_value = [backend_cls]
num_layers = 32
block_size = 16
num_blocks = 8
num_heads = 4
head_size = 16
# TODO (NickLucche) the fact that connector depends on kv_cache_config for init
# but cross-layer preference cant be inferred prior to creating kv_cache_config
# is a bit awkward.
dummy_connector = NixlConnector(
vllm_config,
KVConnectorRole.WORKER,
make_kv_cache_config(block_size=block_size),
)
kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=num_heads,
head_size=head_size,
dtype=torch.float16,
)
if dummy_connector.prefer_cross_layer_blocks:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["all-layers"],
)
for _ in range(num_layers)
],
kv_cache_groups=[KVCacheGroupSpec(["all-layers"], kv_cache_spec)],
)
else:
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(["layer0", "layer1", "layer2"], kv_cache_spec)
],
)
# Create connector
kv_cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER, kv_cache_config)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
......@@ -1526,35 +1585,6 @@ def test_register_kv_caches(
or connector.prefer_cross_layer_blocks
)
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config = make_kv_cache_config(
block_size=block_size, num_blocks=num_blocks
)
connector.connector_worker.kv_cache_config = worker_kv_cache_config
connector.connector_worker.num_blocks = worker_kv_cache_config.num_blocks
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)
with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
......@@ -1586,12 +1616,8 @@ def test_register_kv_caches(
expected_blocks_count = 8
kv_caches = {"all-layers": cross_layers_kv_cache}
else:
# Create test kv cache tensors using proper backend shape
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
......@@ -2261,7 +2287,7 @@ def test_compatibility_hash_validation(
kv_cache_spec = cast(
AttentionSpec, kv_cache_config.kv_cache_groups[0].kv_cache_spec
)
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
kv_cache_shape = decode_worker.attn_backends[0].get_kv_cache_shape(
num_blocks=kv_cache_config.num_blocks,
block_size=kv_cache_spec.block_size,
num_kv_heads=kv_cache_spec.num_kv_heads,
......@@ -2269,10 +2295,14 @@ def test_compatibility_hash_validation(
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=kv_cache_spec.dtype)
# Build kv_caches from the actual layer names in kv_cache_config so that
# _layer_specs lookups in register_kv_caches always find a matching key.
layer_names = [
name for group in kv_cache_config.kv_cache_groups for name in group.layer_names
]
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
name: shared_tensor if i % 2 == 0 else unique_tensor
for i, name in enumerate(layer_names)
}
decode_connector.register_kv_caches(kv_caches)
......@@ -2312,6 +2342,7 @@ def test_compatibility_hash_validation(
block_lens=[4096 * prefill_block_size], # slot_size * block_size
kv_cache_layout="HND",
block_size=prefill_block_size,
ssm_sizes=(0, 0),
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,
......@@ -2391,7 +2422,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
attn_backends=[backend],
tensor_shape=test_shape,
)
......
......@@ -74,6 +74,8 @@ def test_logical_to_kernel_block_ids_with_hma():
# Simulate HMA scenario: logical block size = 32, kernel block size = 16
# So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2
# FA + SW groups (neither is MambaSpec, so both get expanded)
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True)
# Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]]
......@@ -201,3 +203,113 @@ def test_nixl_metadata_hma_block_ids_structure():
assert len(req_meta.remote.block_ids) == 2
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21]
@pytest.mark.cpu_test
def test_get_block_descs_ids_hybrid_ssm():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker,
)
worker = object.__new__(NixlConnectorWorker)
num_blocks = 100
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = 1
# num_descs = num_regions * num_blocks (no blocks_first doubling)
worker.num_descs = 2 * num_blocks
fa_blocks = [3, 5]
ssm_blocks = [1, 2]
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))
# FA group: stride=num_blocks=100, offset=0
# region0: [3, 5], region1: [103, 105]
# SSM group: stride=logical_blocks=100 (=num_blocks/ratio=100/1),
# offset=num_descs=200
# region0: [201, 202], region1: [301, 302]
expected = [3, 5, 103, 105, 201, 202, 301, 302]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
@pytest.mark.cpu_test
def test_get_block_descs_ids_kernel_block_mismatch():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker,
)
worker = object.__new__(NixlConnectorWorker)
ratio = 4
logical_blocks = 100
num_blocks = logical_blocks * ratio # 400 kernel blocks
engine_id = "test-engine"
worker.num_regions = 2
worker.dst_num_blocks = {engine_id: num_blocks}
worker._has_mamba = True
worker._is_mamba_group = [False, True]
worker._physical_blocks_per_logical_kv_block = ratio
worker.num_descs = 2 * num_blocks # 800
fa_blocks = [3, 7] # kernel-level block IDs
ssm_blocks = [1, 2] # logical block IDs
result = worker._get_block_descs_ids(engine_id, (fa_blocks, ssm_blocks))
# FA group: stride=num_blocks=400, offset=0
# region0: [3, 7], region1: [403, 407]
# SSM group: stride=logical_blocks=400//4=100, offset=num_descs=800
# region0: [801, 802], region1: [901, 902]
expected = [3, 7, 403, 407, 801, 802, 901, 902]
assert list(result) == expected, f"Expected {expected}, got {list(result)}"
@pytest.mark.cpu_test
def test_nixl_metadata_hybrid_ssm_block_ids():
"""Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
groups with different block counts (kernel mismatch active)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata,
)
metadata = NixlConnectorMetadata()
# FA: 8 kernel blocks (2 logical * ratio=4), SSM: 2 logical blocks
fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7]
ssm_blocks = [0, 1]
metadata.add_new_req_to_recv(
request_id="test-req-hybrid",
local_block_ids=(fa_blocks, ssm_blocks),
kv_transfer_params={
"remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [20, 21]),
"remote_engine_id": "remote-engine",
"remote_request_id": "prefill-test-req-hybrid",
"remote_host": "localhost",
"remote_port": 1234,
"tp_size": 1,
},
)
assert "test-req-hybrid" in metadata.reqs_to_recv
req_meta = metadata.reqs_to_recv["test-req-hybrid"]
# Verify local block IDs: different lengths per group
assert len(req_meta.local_block_ids) == 2
assert list(req_meta.local_block_ids[0]) == fa_blocks
assert list(req_meta.local_block_ids[1]) == ssm_blocks
assert len(req_meta.local_block_ids[0]) != len(req_meta.local_block_ids[1])
# Verify remote block IDs: same asymmetry preserved
assert req_meta.remote is not None
assert len(req_meta.remote.block_ids) == 2
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [20, 21]
assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])
......@@ -16,10 +16,12 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.v1.kv_cache_interface import KVCacheSpec
logger = init_logger(__name__)
......@@ -328,22 +330,26 @@ class TpKVTopology:
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
attn_backends: list[type[AttentionBackend]]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
is_mamba: bool = False
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
_MOCK_BLOCK_SIZE = 16
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
)
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
attn_backend = self.attn_backends[0]
if not self.is_mamba:
_MOCK_BLOCK_SIZE = 16
kv_cache_shape: tuple[int, ...] = attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=_MOCK_BLOCK_SIZE, num_kv_heads=1, head_size=1
)
logger.debug("Test kv_cache_shape: %s", kv_cache_shape)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
# Hybrid SSM models assume a single blocks_first layout
self._is_kv_layout_blocks_first = self.is_mamba or (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
......@@ -360,7 +366,7 @@ class TpKVTopology:
_MOCK_NUM_LAYERS = 80
kv_cache_shape = (_MOCK_NUM_LAYERS,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
......@@ -483,6 +489,30 @@ class TpKVTopology:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size)
def get_transfer_cache_regions(
self, cache: torch.Tensor, layer_spec: "KVCacheSpec"
) -> list[torch.Tensor] | torch.Tensor:
"""Return the cache tensor(s) to register as NIXL memory regions,
also accounting for hybrid SSM models specificities.
"""
if isinstance(layer_spec, MambaSpec):
# Register the whole kv cache shared tensor, including SSM/Conv. This is
# similar to FI with the difference that SSM/Conv have different sizes
conv, ssm = cache
return [conv]
# Check may be hacky but it's matching `_update_hybrid_attention_mamba_layout`.
if self.is_mamba and cache.shape[0] == 2:
# When MAMBA is present, all backends are blocks first, so that blocks
# can be shared between attention layers and mamba layers. Runner
# `_update_hybrid_attention_mamba_layout` already adjusted strides
# for FlashAttn-like backends so its num_blocks first.
# Swap [2<>num_blocks] dims to get required layout for hybrid SSM.
cache = cache.transpose(0, 1)
# Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache]
def get_current_attn_backends(
vllm_config: VllmConfig, layer_names: list[str] | None = None
......
......@@ -564,7 +564,7 @@ class MooncakeConnectorWorker:
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
attn_backends=[backend],
)
self.async_zmq_ctx = zmq.asyncio.Context()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment