"vllm/vscode:/vscode.git/clone" did not exist on "9659bc7f271ec640da780b5ca739e261764b954b"
Unverified Commit d49f2731 authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

[SSM/Mamba] Follow-up: N-1 prefill for P/D disaggregation (#37310)

parent b21d3843
...@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging( ...@@ -2007,7 +2007,7 @@ def test_transfer_failure_logging(
connector = NixlConnector( connector = NixlConnector(
vllm_config, vllm_config,
KVConnectorRole.WORKER, KVConnectorRole.WORKER,
make_kv_cache_config(block_size=16, hma_enabled=enable_hma), make_kv_cache_config(block_size=16, swa_enabled=enable_hma),
) )
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, vllm_config,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA.""" """Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill."""
from unittest.mock import patch from unittest.mock import patch
...@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import ( ...@@ -14,24 +14,26 @@ from vllm.v1.core.single_type_kv_cache_manager import (
) )
from .utils import ( from .utils import (
create_request,
create_vllm_config, create_vllm_config,
make_kv_cache_config, make_kv_cache_config,
make_nixl_scheduler,
) )
@pytest.mark.cpu_test @pytest.mark.cpu_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"hma_enabled,expected_sw_sizes", "swa_enabled,expected_sw_sizes",
[ [
# HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) # SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128)
(True, [0, 128 + 1]), (True, [0, 128 + 1]),
# HMA disabled: only FullAttentionSpec (0) # SWA disabled: only FullAttentionSpec (0)
(False, [0]), (False, [0]),
], ],
) )
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on HMA enabled/disabled.""" """Test sw_sizes is correctly computed based on SWA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
...@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): ...@@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
vllm_config = create_vllm_config(block_size=block_size) vllm_config = create_vllm_config(block_size=block_size)
# SW 2048 tokens=>128 blocks # SW 2048 tokens=>128 blocks
kv_cache_config = make_kv_cache_config( kv_cache_config = make_kv_cache_config(
block_size=block_size, hma_enabled=hma_enabled, sw_size=2048 block_size=block_size, swa_enabled=swa_enabled, sw_size=2048
) )
scheduler = NixlConnectorScheduler( scheduler = NixlConnectorScheduler(
...@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma(): ...@@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma():
# So each logical block maps to 2 kernel blocks eg [0]->[0,1] # So each logical block maps to 2 kernel blocks eg [0]->[0,1]
worker._physical_blocks_per_logical_kv_block = 2 worker._physical_blocks_per_logical_kv_block = 2
# FA + SW groups (neither is MambaSpec, so both get expanded) # FA + SW groups (neither is MambaSpec, so both get expanded)
worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True) worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True)
# Test conversion: FA + SW group # Test conversion: FA + SW group
logical_block_ids = [[0, 1, 2], [3, 4]] logical_block_ids = [[0, 1, 2], [3, 4]]
...@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids(): ...@@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids():
assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17]
assert list(req_meta.remote.block_ids[1]) == [20, 21] assert list(req_meta.remote.block_ids[1]) == [20, 21]
assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1]) assert len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1])
# ── Mamba N-1 prefill tests ──────────────────────────────────────────────
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"has_mamba,is_hma_required,expected_count",
[
(True, True, 9),
(False, False, 10),
(False, True, 10),
],
ids=["mamba", "fa_only", "swa_only"],
)
def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count):
"""D-side: Mamba gets N-1 matched tokens, non-Mamba gets N."""
sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required)
req = create_request(num_tokens=10, do_remote_prefill=True)
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert count == expected_count
assert is_async is True
@pytest.mark.cpu_test
def test_mamba_n1_p_side_truncation():
"""P-side: Mamba truncates prompt to N-1, sets max_tokens=1.
Also verifies idempotency (calling again is a no-op) which is
needed for preemption safety via the _p_side_truncated guard,
and that non-Mamba models skip truncation entirely.
"""
sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True)
req = create_request(num_tokens=10, do_remote_decode=True)
req.max_tokens = 128
original_len = len(req.prompt_token_ids)
count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert count == 0
assert is_async is False
assert len(req.prompt_token_ids) == original_len - 1
assert req.num_prompt_tokens == original_len - 1
assert req.max_tokens == 1
assert req.kv_transfer_params["_p_side_truncated"] is True
# Idempotency: second call must not truncate further
sched.get_num_new_matched_tokens(req, num_computed_tokens=0)
assert len(req.prompt_token_ids) == original_len - 1
# Non-Mamba: truncation is skipped
fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False)
fa_req = create_request(num_tokens=10, do_remote_decode=True)
fa_original = len(fa_req.prompt_token_ids)
fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0)
assert len(fa_req.prompt_token_ids) == fa_original
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma",
[
(True, True, True, True),
(True, False, False, True),
(False, False, False, False),
],
ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_has_mamba_init(
mock_platform,
swa_enabled,
mamba_enabled,
expected_has_mamba,
expected_is_hma,
):
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
mock_platform.device_type = "cpu"
block_size = 16
vllm_config = create_vllm_config(block_size=block_size)
# VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config
# is set; override so we can test the scheduler's own derivation.
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=block_size,
swa_enabled=swa_enabled,
mamba_enabled=mamba_enabled,
)
scheduler = NixlConnectorScheduler(
vllm_config=vllm_config,
engine_id="test-engine",
kv_cache_config=kv_cache_config,
)
assert scheduler._has_mamba is expected_has_mamba
assert scheduler._is_hma_required is expected_is_hma
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
from unittest.mock import patch
import pytest import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.request import FinishReason, RequestStatus from vllm.v1.request import FinishReason, RequestStatus
from .utils import ( from .utils import (
...@@ -13,6 +18,7 @@ from .utils import ( ...@@ -13,6 +18,7 @@ from .utils import (
create_request, create_request,
create_scheduler, create_scheduler,
create_vllm_config, create_vllm_config,
make_kv_cache_config,
) )
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
...@@ -579,3 +585,73 @@ def test_cannot_recv(): ...@@ -579,3 +585,73 @@ def test_cannot_recv():
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule() _ = scheduler.schedule()
assert_scheduler_empty(scheduler) assert_scheduler_empty(scheduler)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_p_side_chunked_prefill_mamba(mock_platform):
"""P-side integration: Mamba N-1 truncation + chunked prefill completes.
A 64-token P-side request is truncated to 63 by the N-1 fix, then
chunked into two prefill steps (32 + 31) and finishes with
LENGTH_CAPPED because max_tokens is set to 1.
"""
mock_platform.device_type = "cpu"
BATCH_SIZE = 32
NUM_TOKENS = 64
BLOCK_SIZE = 16
vllm_config = create_vllm_config(
max_num_batched_tokens=BATCH_SIZE,
block_size=BLOCK_SIZE,
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
kv_cache_config = make_kv_cache_config(
block_size=BLOCK_SIZE,
mamba_enabled=True,
num_blocks=10000,
)
scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config)
request = create_request(
num_tokens=NUM_TOKENS,
do_remote_decode=True,
block_size=BLOCK_SIZE,
)
request.max_tokens = 128
scheduler.add_request(request)
request_id = request.request_id
# ── Step 1: first chunk ──
scheduler_output = scheduler.schedule()
assert len(request.prompt_token_ids) == NUM_TOKENS - 1
assert request.max_tokens == 1
assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE
assert request.num_computed_tokens == BATCH_SIZE
# Model returns no tokens for intermediate prefill chunk
intermediate_output = ModelRunnerOutput(
req_ids=[request.request_id],
req_id_to_index={request.request_id: 0},
sampled_token_ids=[[]],
)
scheduler.update_from_output(scheduler_output, intermediate_output)
# ── Step 2: remaining chunk ──
scheduler_output = scheduler.schedule()
remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31
assert scheduler_output.num_scheduled_tokens[request_id] == remaining
assert request.num_computed_tokens == NUM_TOKENS - 1
# Prefill complete: model generates 1 decode token
final_output = create_model_runner_output([request])
engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output)
# max_tokens=1 → request finishes with LENGTH
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
assert outputs[0].finish_reason == FinishReason.LENGTH
...@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -37,6 +37,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheConfig,
KVCacheGroupSpec, KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec, SlidingWindowSpec,
) )
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
...@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector( ...@@ -423,7 +424,8 @@ KVConnectorFactory.register_connector(
def make_kv_cache_config( def make_kv_cache_config(
block_size: int, block_size: int,
hma_enabled: bool = False, swa_enabled: bool = False,
mamba_enabled: bool = False,
sw_size: int = 128, sw_size: int = 128,
num_blocks: int = 100, num_blocks: int = 100,
) -> KVCacheConfig: ) -> KVCacheConfig:
...@@ -438,7 +440,7 @@ def make_kv_cache_config( ...@@ -438,7 +440,7 @@ def make_kv_cache_config(
), ),
) )
] ]
if hma_enabled: if swa_enabled:
kv_cache_groups.append( kv_cache_groups.append(
KVCacheGroupSpec( KVCacheGroupSpec(
["layer1", "layer3"], ["layer1", "layer3"],
...@@ -451,6 +453,32 @@ def make_kv_cache_config( ...@@ -451,6 +453,32 @@ def make_kv_cache_config(
), ),
) )
) )
if mamba_enabled:
kv_cache_groups.append(
KVCacheGroupSpec(
["mamba0", "mamba1"],
MambaSpec(
block_size=block_size,
shapes=((16,), (16,)),
dtypes=(torch.float16,),
),
)
)
return KVCacheConfig( return KVCacheConfig(
num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups
) )
def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
"""Create a NixlConnectorScheduler via __new__ (skipping __init__).
Only sets the two flags needed by the N-1 prefill logic.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorScheduler,
)
sched = object.__new__(NixlConnectorScheduler)
sched._has_mamba = has_mamba
sched._is_hma_required = is_hma_required
return sched
...@@ -572,6 +572,10 @@ class NixlConnectorScheduler: ...@@ -572,6 +572,10 @@ class NixlConnectorScheduler:
for g in kv_cache_config.kv_cache_groups for g in kv_cache_config.kv_cache_groups
) )
) )
self._has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec)
for g in kv_cache_config.kv_cache_groups
)
logger.info("Initializing NIXL Scheduler %s", engine_id) logger.info("Initializing NIXL Scheduler %s", engine_id)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
...@@ -717,6 +721,39 @@ class NixlConnectorScheduler: ...@@ -717,6 +721,39 @@ class NixlConnectorScheduler:
logger.warning("Connection listener got unexpected message %s", msg) logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
"""D-side only. Returns N-1 for Mamba models since the decoder
always recomputes the last token and must start from h(N-1)."""
if self._has_mamba and num_prompt_tokens > 1:
return num_prompt_tokens - 1
return num_prompt_tokens
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
"""P-side only: drop the last prompt token so the prefiller computes
h(N-1) instead of h(N). The decoder recomputes the last token to
derive h(N) correctly.
Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
request is preempted and rescheduled."""
params = request.kv_transfer_params
if (
params is not None
# Guard against repeated truncation after preemption/reschedule.
and not params.get("_p_side_truncated")
and request.num_prompt_tokens > 1
):
if request.prompt_token_ids is not None:
request.prompt_token_ids.pop()
elif request.prompt_embeds is not None:
request.prompt_embeds = request.prompt_embeds[:-1]
else:
return
request._all_token_ids.pop()
request.num_prompt_tokens -= 1
request.max_tokens = 1
params["_p_side_truncated"] = True
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]: ) -> tuple[int, bool]:
...@@ -746,10 +783,14 @@ class NixlConnectorScheduler: ...@@ -746,10 +783,14 @@ class NixlConnectorScheduler:
if params is not None and params.get("do_remote_prefill"): if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote. # Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or [] token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens actual = self._mamba_prefill_token_count(len(token_ids))
count = actual - num_computed_tokens
if count > 0: if count > 0:
return count, True return count, True
if params is not None and params.get("do_remote_decode") and self._has_mamba:
self._truncate_mamba_request_for_prefill(request)
# No remote prefill for this request. # No remote prefill for this request.
return 0, False return 0, False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment