Unverified Commit 3f83c597 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

fix(vllm): classify disagg decode-engine queued requests correctly in FPM (#8471)


Signed-off-by: default avatarHongkuan Zhou <hongkuanz@nvidia.com>
Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent d4ac0b59
......@@ -451,7 +451,35 @@ class InstrumentedScheduler(AsyncScheduler):
)
def _compute_queued(self) -> QueuedRequestMetrics:
"""Single-pass aggregation over self.waiting -- no intermediate list."""
"""Single-pass aggregation over ``self.waiting`` and ``self.skipped_waiting``.
vLLM's scheduler parks requests in two queues:
* ``self.waiting`` holds requests in ``WAITING`` (new, never scheduled)
and ``PREEMPTED`` (were decoding, evicted back for memory) states.
* ``self.skipped_waiting`` holds "blocked-waiting" requests awaiting an
async precondition — see ``Scheduler._is_blocked_waiting_status`` /
``Scheduler._enqueue_waiting_request``:
WAITING_FOR_FSM -- grammar/structured-output compile
WAITING_FOR_REMOTE_KVS -- disagg decode-engine KV transfer
WAITING_FOR_STREAMING_REQ -- streaming request handshake
A ``WAITING_FOR_REMOTE_KVS`` request is a **decode** request: the
prefill engine has already computed its KV and is transferring it; once
finished the request goes straight to decode without a local prefill
step. ``num_computed_tokens`` is pre-set to the transferred KV length
(see ``Scheduler.schedule`` at the ``load_kv_async`` branch), so it is
the correct decode-KV-context value for FPM purposes.
``WAITING_FOR_FSM`` / ``WAITING_FOR_STREAMING_REQ`` have no KV computed
yet — they are queued prefill requests blocked on a precondition.
Only iterating ``self.waiting`` (the previous behaviour) silently
misses every ``WAITING_FOR_REMOTE_KVS`` request on the decode engine
in disaggregated serving, and misclassifies it as queued prefill if it
ever transiently appears in ``self.waiting``.
"""
prefill = WelfordAccumulator()
decode_kv = WelfordAccumulator()
......@@ -461,6 +489,17 @@ class InstrumentedScheduler(AsyncScheduler):
else:
prefill.add(request.num_tokens)
for request in self.skipped_waiting:
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# Disagg decode side: KV already computed on the prefill
# engine and being transferred. Next schedule() step will
# start generating -- count as queued decode.
decode_kv.add(request.num_computed_tokens)
else:
# WAITING_FOR_FSM / WAITING_FOR_STREAMING_REQ: no KV yet,
# essentially a queued prefill awaiting a precondition.
prefill.add(request.num_tokens)
return QueuedRequestMetrics(
num_prefill_requests=prefill.n,
sum_prefill_tokens=prefill.s,
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ``InstrumentedScheduler._compute_queued`` classification.
Focus: correct handling of ``self.skipped_waiting`` and disaggregated-serving
request states. The production scheduler is heavy to construct (needs a real
``VllmConfig`` + ``KVCacheConfig`` + ``StructuredOutputManager``), so these
tests invoke ``_compute_queued`` as an unbound method against a minimal stub
built with ``object.__new__`` — this exercises the real function body without
spinning up vLLM engine internals.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from vllm.v1.request import RequestStatus # noqa: E402
# Module-level import: triggers real site-packages ``vllm`` to load before
# pytest's rootpath insertion adds ``components/src/dynamo`` to ``sys.path``
# (which shadows the real ``vllm`` with the ``dynamo.vllm`` submodule for any
# later bare ``import vllm``). Mirrors the pattern in ``test_vllm_unit.py``,
# which imports ``dynamo.vllm.args`` at module level for the same reason.
# If this import is deferred to inside a test body, the real ``vllm`` will
# not be resolvable and ``instrumented_scheduler`` will fail to load with
# ``ModuleNotFoundError: No module named 'vllm.sampling_params'``.
from dynamo.vllm.instrumented_scheduler import InstrumentedScheduler # noqa: E402
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.pre_merge,
]
def _make_request(status, num_tokens: int, num_computed_tokens: int = 0):
"""Build a minimal stand-in for ``vllm.v1.request.Request``.
Only the three attributes read by ``_compute_queued`` are populated.
"""
return SimpleNamespace(
status=status,
num_tokens=num_tokens,
num_computed_tokens=num_computed_tokens,
)
def _run_compute_queued(waiting, skipped_waiting):
"""Invoke the real ``InstrumentedScheduler._compute_queued`` on a stub.
Bypasses ``__init__`` (which needs full vLLM config) and populates only
the two attributes the method reads.
"""
stub = InstrumentedScheduler.__new__(InstrumentedScheduler)
stub.waiting = waiting
stub.skipped_waiting = skipped_waiting
return InstrumentedScheduler._compute_queued(stub)
# ---------------------------------------------------------------------------
# self.waiting classification (existing behaviour — regression coverage)
# ---------------------------------------------------------------------------
def test_waiting_new_requests_count_as_queued_prefill():
q = _run_compute_queued(
waiting=[
_make_request(RequestStatus.WAITING, num_tokens=100),
_make_request(RequestStatus.WAITING, num_tokens=200),
],
skipped_waiting=[],
)
assert q.num_prefill_requests == 2
assert q.sum_prefill_tokens == 300
assert q.num_decode_requests == 0
assert q.sum_decode_kv_tokens == 0
def test_waiting_preempted_requests_count_as_queued_decode():
q = _run_compute_queued(
waiting=[
_make_request(
RequestStatus.PREEMPTED, num_tokens=512, num_computed_tokens=480
),
_make_request(
RequestStatus.PREEMPTED, num_tokens=256, num_computed_tokens=240
),
],
skipped_waiting=[],
)
assert q.num_prefill_requests == 0
assert q.sum_prefill_tokens == 0
assert q.num_decode_requests == 2
# sum_decode_kv_tokens = sum of num_computed_tokens
assert q.sum_decode_kv_tokens == 720
# ---------------------------------------------------------------------------
# self.skipped_waiting classification (the fix)
# ---------------------------------------------------------------------------
def test_skipped_waiting_for_remote_kvs_counts_as_queued_decode():
"""Disagg decode-engine: request has KV being transferred; should count
as queued decode, not queued prefill.
"""
q = _run_compute_queued(
waiting=[],
skipped_waiting=[
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=1000,
num_computed_tokens=1000,
),
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=500,
num_computed_tokens=500,
),
],
)
# Must NOT be classified as prefill.
assert q.num_prefill_requests == 0
assert q.sum_prefill_tokens == 0
# Classified as decode with KV = num_computed_tokens.
assert q.num_decode_requests == 2
assert q.sum_decode_kv_tokens == 1500
def test_skipped_waiting_for_fsm_counts_as_queued_prefill():
"""Structured-output FSM compile wait has no KV computed yet; prefill."""
q = _run_compute_queued(
waiting=[],
skipped_waiting=[
_make_request(RequestStatus.WAITING_FOR_FSM, num_tokens=128),
],
)
assert q.num_prefill_requests == 1
assert q.sum_prefill_tokens == 128
assert q.num_decode_requests == 0
assert q.sum_decode_kv_tokens == 0
def test_skipped_waiting_for_streaming_req_counts_as_queued_prefill():
q = _run_compute_queued(
waiting=[],
skipped_waiting=[
_make_request(RequestStatus.WAITING_FOR_STREAMING_REQ, num_tokens=64),
],
)
assert q.num_prefill_requests == 1
assert q.sum_prefill_tokens == 64
assert q.num_decode_requests == 0
# ---------------------------------------------------------------------------
# Mixed scenarios -- the realistic disagg decode engine picture
# ---------------------------------------------------------------------------
def test_mixed_disagg_decode_engine_snapshot():
"""Realistic decode-engine snapshot: some local preempts in ``self.waiting``
plus many ``WAITING_FOR_REMOTE_KVS`` in ``self.skipped_waiting``.
"""
q = _run_compute_queued(
waiting=[
_make_request(
RequestStatus.PREEMPTED, num_tokens=800, num_computed_tokens=780
),
],
skipped_waiting=[
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=1024,
num_computed_tokens=1024,
),
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=2048,
num_computed_tokens=2048,
),
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=512,
num_computed_tokens=512,
),
],
)
# 0 queued prefill on the decode engine under healthy disagg.
assert q.num_prefill_requests == 0
assert q.sum_prefill_tokens == 0
# 1 preempted (local decode evicted) + 3 remote-KV-waiting.
assert q.num_decode_requests == 4
assert q.sum_decode_kv_tokens == 780 + 1024 + 2048 + 512
def test_mixed_prefill_engine_snapshot():
"""Prefill engine side: all queued work is prefill across both queues."""
q = _run_compute_queued(
waiting=[
_make_request(RequestStatus.WAITING, num_tokens=100),
_make_request(RequestStatus.WAITING, num_tokens=200),
],
skipped_waiting=[
_make_request(RequestStatus.WAITING_FOR_FSM, num_tokens=300),
],
)
assert q.num_prefill_requests == 3
assert q.sum_prefill_tokens == 600
assert q.num_decode_requests == 0
def test_empty_queues():
q = _run_compute_queued(waiting=[], skipped_waiting=[])
assert q.num_prefill_requests == 0
assert q.sum_prefill_tokens == 0
assert q.num_decode_requests == 0
assert q.sum_decode_kv_tokens == 0
assert q.var_prefill_length == 0.0
assert q.var_decode_kv_tokens == 0.0
# ---------------------------------------------------------------------------
# Variance correctness across both queues
# ---------------------------------------------------------------------------
def test_variance_spans_both_queues():
"""Variance is computed over the union of both queues, not each in
isolation. Using lengths 100 and 300 → mean 200, var 10000 (population).
"""
q = _run_compute_queued(
waiting=[
_make_request(RequestStatus.WAITING, num_tokens=100),
],
skipped_waiting=[
_make_request(RequestStatus.WAITING_FOR_FSM, num_tokens=300),
],
)
assert q.num_prefill_requests == 2
assert q.sum_prefill_tokens == 400
# Population variance of [100, 300] = 10000.
assert q.var_prefill_length == pytest.approx(10000.0)
def test_decode_variance_spans_both_queues():
"""Decode variance mixes local-preempted (``self.waiting``) and
remote-KV-waiting (``self.skipped_waiting``) into one accumulator.
KV lengths 500 and 1500 → mean 1000, population variance 250000.
"""
q = _run_compute_queued(
waiting=[
_make_request(
RequestStatus.PREEMPTED, num_tokens=520, num_computed_tokens=500
),
],
skipped_waiting=[
_make_request(
RequestStatus.WAITING_FOR_REMOTE_KVS,
num_tokens=1500,
num_computed_tokens=1500,
),
],
)
assert q.num_decode_requests == 2
assert q.sum_decode_kv_tokens == 2000
assert q.var_decode_kv_tokens == pytest.approx(250000.0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment