Unverified commit 85f671b8, authored by Santino Ramos and committed by GitHub
Browse files

[Model Runner V2] Support Streaming Inputs (#37028)


Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
parent 8bc6b5cd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
from unittest.mock import Mock
import pytest
import torch
from vllm.v1.core.sched.output import (
CachedRequestData,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.gpu.states import RequestState
# Apply the `cpu_test` marker to every test collected from this module.
pytestmark = pytest.mark.cpu_test
@pytest.fixture
def mock_model_runner_with_req_states():
    """Build a spec'd ``GPUModelRunner`` mock backed by a real ``RequestState``.

    The runner's collaborators are replaced with mocks (or ``None``), while
    ``add_requests`` and ``_remove_request`` are the real implementations
    bound to the mock so the tests exercise the actual logic.
    """
    # Real request bookkeeping, placed on the CPU device.
    req_states = RequestState(
        max_num_reqs=10,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        num_speculative_steps=0,
        vocab_size=32000,
        device=torch.device("cpu"),
        model_dtype=torch.float32,
        cache_draft_logits=False,
    )
    # Staged writes use Triton kernels that require a GPU, so stub them out.
    req_states.apply_staged_writes = Mock()

    runner = Mock(spec=GPUModelRunner)
    runner.req_states = req_states

    # Components that are disabled for these tests.
    runner.encoder_cache = None
    runner.sampler = None
    runner.prompt_logprobs_worker = None
    runner.is_last_pp_rank = False

    # Components whose calls are only recorded.
    for attr_name in ("model_state", "block_tables", "lora_state"):
        setattr(runner, attr_name, Mock())

    # Bind the real methods under test to the mock instance.
    for method_name in ("_remove_request", "add_requests"):
        real_method = getattr(GPUModelRunner, method_name)
        setattr(runner, method_name, real_method.__get__(runner))
    return runner
def _make_scheduler_output(new_reqs):
    """Wrap ``new_reqs`` in a ``SchedulerOutput`` whose other fields are empty."""
    fields = {
        "scheduled_new_reqs": new_reqs,
        "scheduled_cached_reqs": CachedRequestData.make_empty(),
        "num_scheduled_tokens": {},
        "total_num_scheduled_tokens": 0,
        "scheduled_spec_decode_tokens": {},
        "scheduled_encoder_inputs": {},
        "num_common_prefix_blocks": [],
        "finished_req_ids": set(),
        "free_encoder_mm_hashes": [],
    }
    return SchedulerOutput(**fields)
def test_e2e_streaming_request_update_basic_flow(
    mock_model_runner_with_req_states,
):
    """Re-adding an existing request id replaces its state instead of leaking.

    After a streaming update that carries a longer prefill, this checks that:
    1. The number of free indices is unchanged (the old slot was released).
    2. Prompt length, prefill length and computed-token count match the update.
    3. model_state and block_tables were called again for the new index.
    """
    runner = mock_model_runner_with_req_states
    states = runner.req_states
    session_id = "streaming_req_0"
    free_before = len(states.free_indices)

    def build_request(prefill, blocks, computed):
        # Fresh list objects on every call so the two requests share nothing.
        return NewRequestData(
            req_id=session_id,
            prompt_token_ids=[1, 2, 3],
            prefill_token_ids=prefill,
            mm_features=[],
            sampling_params=None,
            pooling_params=None,
            block_ids=blocks,
            num_computed_tokens=computed,
            lora_request=None,
        )

    # Step 1: initial request with 3 prompt tokens, all computed.
    first_chunk = build_request([1, 2, 3], ([0],), 3)
    runner.add_requests(_make_scheduler_output([first_chunk]))
    assert session_id in states.req_id_to_index
    assert len(states.free_indices) == free_before - 1

    # Step 2: streaming update with an extended prompt. The scheduler has
    # already set prefill_token_ids to the full sequence (original prompt +
    # intermediate output + new prompt tokens); 4 computed tokens means
    # 3 original prompt + 1 intermediate output.
    second_chunk = build_request([1, 2, 3, 10, 4, 5], ([0, 1],), 4)
    runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: no free_indices leak (old slot recycled).
    assert len(states.free_indices) == free_before - 1

    # The request is still tracked, under exactly one index.
    assert session_id in states.req_id_to_index
    tracked_ids = list(states.index_to_req_id.values())
    assert tracked_ids.count(session_id) == 1

    # State holds the values from the update.
    slot = states.req_id_to_index[session_id]
    assert states.prompt_len.np[slot] == 3
    assert states.prefill_len.np[slot] == 6
    assert states.num_computed_prefill_tokens[slot] == 4

    # model_state and block_tables were re-registered for the new slot.
    runner.model_state.add_request.assert_called_with(slot, second_chunk)
    runner.block_tables.append_block_ids.assert_called_with(
        slot, ([0, 1],), overwrite=True
    )
def test_e2e_streaming_with_multimodal_features(
    mock_model_runner_with_req_states,
):
    """Streaming updates with multimodal features refresh every component.

    After a streaming update that adds a second mm feature, this checks that:
    1. The number of free indices shows no leak.
    2. encoder_cache saw one remove_request and one add_request carrying
       the new mm_features list.
    3. model_state.add_request was called once with the updated request data.
    """
    runner = mock_model_runner_with_req_states
    states = runner.req_states
    session_id = "streaming_mm_req_0"
    free_before = len(states.free_indices)

    # Enable encoder_cache for multimodal.
    runner.encoder_cache = Mock()

    # Shared token layout: two text tokens, ten placeholders, two text tokens.
    base_tokens = [1, 2] + [0] * 10 + [3, 4]

    # Step 1: initial request with one audio feature.
    first_feature = Mock()
    first_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=list(base_tokens),
        prefill_token_ids=list(base_tokens),
        mm_features=[first_feature],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0],),
        num_computed_tokens=14,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([first_chunk]))
    assert session_id in states.req_id_to_index

    # Forget the calls above so only the streaming update is tracked.
    for recorded in (runner.encoder_cache, runner.model_state):
        recorded.reset_mock()

    # Step 2: streaming update with an additional multimodal feature. The
    # scheduler has folded the intermediate output (100) into
    # prefill_token_ids and added a new audio chunk.
    second_feature = Mock()
    second_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=list(base_tokens),
        prefill_token_ids=base_tokens + [100] + [0] * 5 + [5],
        mm_features=[first_feature, second_feature],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0, 1],),
        num_computed_tokens=14,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: no free_indices leak, and exactly one index for the request.
    assert len(states.free_indices) == free_before - 1
    tracked_ids = list(states.index_to_req_id.values())
    assert tracked_ids.count(session_id) == 1

    # encoder_cache was cleaned up and re-registered.
    runner.encoder_cache.remove_request.assert_called_once_with(session_id)
    runner.encoder_cache.add_request.assert_called_once_with(
        session_id, [first_feature, second_feature]
    )

    # model_state was re-registered with the new data.
    slot = states.req_id_to_index[session_id]
    runner.model_state.add_request.assert_called_once_with(slot, second_chunk)

    # Updated prefill length: 14 base tokens + 1 + 5 + 1.
    assert states.prefill_len.np[slot] == 21
...@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling(
new_common_attn_metadata.query_start_loc *= block_pool_size new_common_attn_metadata.query_start_loc *= block_pool_size
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
new_common_attn_metadata.seq_lens *= block_pool_size new_common_attn_metadata.seq_lens *= block_pool_size
new_common_attn_metadata._seq_lens_cpu *= block_pool_size if new_common_attn_metadata._seq_lens_cpu is not None:
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size new_common_attn_metadata._seq_lens_cpu *= block_pool_size
if new_common_attn_metadata._num_computed_tokens_cpu is not None:
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
new_common_attn_metadata.num_actual_tokens *= block_pool_size new_common_attn_metadata.num_actual_tokens *= block_pool_size
new_common_attn_metadata.max_query_len *= block_pool_size new_common_attn_metadata.max_query_len *= block_pool_size
new_common_attn_metadata.max_seq_len *= block_pool_size new_common_attn_metadata.max_seq_len *= block_pool_size
......
...@@ -111,6 +111,7 @@ def _reshape_kv_cache( ...@@ -111,6 +111,7 @@ def _reshape_kv_cache(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor], kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend], attn_backends: dict[str, AttentionBackend],
cache_dtype: str,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups: for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
...@@ -127,6 +128,7 @@ def _reshape_kv_cache( ...@@ -127,6 +128,7 @@ def _reshape_kv_cache(
kv_cache_spec.block_size, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size, kv_cache_spec.head_size,
cache_dtype,
) )
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends. # FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
...@@ -155,9 +157,12 @@ def init_kv_cache( ...@@ -155,9 +157,12 @@ def init_kv_cache(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend], attn_backends: dict[str, AttentionBackend],
device: torch.device, device: torch.device,
cache_dtype: str,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends) kv_caches = _reshape_kv_cache(
kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches) bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches return kv_caches
......
...@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config, self.kv_cache_config,
self.attn_backends, self.attn_backends,
self.device, self.device,
self.cache_config.cache_dtype,
) )
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
...@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return cuda_graph_size return cuda_graph_size
def _remove_request(self, req_id: str) -> bool:
if not self.req_states.remove_request(req_id):
return False
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
return True
def finish_requests(self, scheduler_output: SchedulerOutput) -> None: def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
finished_req_ids = scheduler_output.finished_req_ids finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids preempted_req_ids = scheduler_output.preempted_req_ids
if preempted_req_ids: if preempted_req_ids:
finished_req_ids = finished_req_ids.union(preempted_req_ids) finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids: for req_id in finished_req_ids:
self.req_states.remove_request(req_id) self._remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None: def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.encoder_cache is not None: if self.encoder_cache is not None:
...@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert new_req_data.prompt_token_ids is not None assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
req_id = new_req_data.req_id req_id = new_req_data.req_id
# Streaming input update: request already exists from a prior
# chunk. Remove old state so it can be cleanly re-added below
# with the updated prompt_token_ids and mm_features.
self._remove_request(req_id)
prompt_len = len(new_req_data.prompt_token_ids) prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request( self.req_states.add_request(
req_id=req_id, req_id=req_id,
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.tasks import GenerationTask
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
...@@ -61,6 +62,28 @@ class DefaultModelState(ModelState): ...@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
device=self.device, device=self.device,
) )
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
    """Return the generation task names that ``self.model`` supports.

    The result may contain "generate", "transcription" and "realtime",
    in that order. If the model supports transcription and its
    ``supports_transcription_only`` attribute is truthy, the result is
    exactly ``("transcription",)``.
    """
    # Imported inside the method, as in the original code.
    from vllm.model_executor.models.interfaces import (
        supports_realtime,
        supports_transcription,
    )
    from vllm.model_executor.models.interfaces_base import is_text_generation_model

    model = self.model
    tasks: list[GenerationTask] = []

    if is_text_generation_model(model):
        tasks.append("generate")

    if supports_transcription(model):
        # Transcription-only models short-circuit: nothing else is reported,
        # not even a "generate" entry collected above.
        if model.supports_transcription_only:
            return ("transcription",)
        tasks.append("transcription")

    if supports_realtime(model):
        tasks.append("realtime")

    return tuple(tasks)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
if self.rope_state is not None: if self.rope_state is not None:
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
......
...@@ -28,8 +28,9 @@ class ModelState(ABC): ...@@ -28,8 +28,9 @@ class ModelState(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]: def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
return ("generate",) raise NotImplementedError
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
return None return None
......
...@@ -109,13 +109,14 @@ class RequestState: ...@@ -109,13 +109,14 @@ class RequestState:
self.all_token_ids.apply_write() self.all_token_ids.apply_write()
self.num_computed_tokens.apply_write() self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None: def remove_request(self, req_id: str) -> bool:
req_idx = self.req_id_to_index.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None: if req_idx is None:
# Request not found. # Request not found.
return return False
self.index_to_req_id.pop(req_idx, None) self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx) self.free_indices.append(req_idx)
return True
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool: def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
return np.any( return np.any(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment