Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85f671b8
Unverified
Commit
85f671b8
authored
Mar 20, 2026
by
Santino Ramos
Committed by
GitHub
Mar 20, 2026
Browse files
[Model Runner V2] Support Streaming Inputs (#37028)
Signed-off-by:
Santino Ramos
<
elsantinoramos@gmail.com
>
parent
8bc6b5cd
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
263 additions
and
12 deletions
+263
-12
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
.../v1/streaming_input/test_gpu_model_runner_v2_streaming.py
+207
-0
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+4
-2
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+6
-1
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+18
-6
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+23
-0
vllm/v1/worker/gpu/model_states/interface.py
vllm/v1/worker/gpu/model_states/interface.py
+2
-1
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+3
-2
No files found.
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
0 → 100644
View file @
85f671b8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
from
unittest.mock
import
Mock
import
pytest
import
torch
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
,
)
from
vllm.v1.worker.gpu.model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu.states
import
RequestState
pytestmark
=
pytest
.
mark
.
cpu_test
@pytest.fixture
def mock_model_runner_with_req_states():
    """Return a spec'd GPUModelRunner mock wired to a real RequestState.

    The RequestState lives on the CPU. Collaborators the runner touches
    while adding/removing requests are replaced by plain mocks (or None),
    and the real ``_remove_request`` / ``add_requests`` implementations are
    bound onto the mock so the tests exercise the production logic.
    """
    req_states = RequestState(
        max_num_reqs=10,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        num_speculative_steps=0,
        vocab_size=32000,
        device=torch.device("cpu"),
        model_dtype=torch.float32,
        cache_draft_logits=False,
    )
    # Staged writes use Triton kernels that require a GPU, so stub them out.
    req_states.apply_staged_writes = Mock()

    runner = Mock(spec=GPUModelRunner)
    runner.req_states = req_states

    # Optional components that are disabled by default in these tests.
    for disabled in ("encoder_cache", "sampler", "prompt_logprobs_worker"):
        setattr(runner, disabled, None)
    runner.is_last_pp_rank = False

    # Collaborators whose calls the tests inspect.
    for mocked in ("model_state", "block_tables", "lora_state"):
        setattr(runner, mocked, Mock())

    # Bind the real methods under test to the mock instance.
    for method_name in ("_remove_request", "add_requests"):
        real_method = getattr(GPUModelRunner, method_name)
        setattr(runner, method_name, real_method.__get__(runner))

    return runner
def _make_scheduler_output(new_reqs):
    """Wrap ``new_reqs`` in a SchedulerOutput whose other fields are empty.

    Only ``scheduled_new_reqs`` carries data; there are no cached requests,
    no scheduled tokens, no encoder inputs and no finished requests.
    """
    no_cached_reqs = CachedRequestData.make_empty()
    output = SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=no_cached_reqs,
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
    return output
def test_e2e_streaming_request_update_basic_flow(
    mock_model_runner_with_req_states,
):
    """Re-adding an existing request id replaces its state in place.

    A request is added, then added again under the same id with a longer
    ``prefill_token_ids``. The test checks that:
    1. Exactly one slot stays occupied (no ``free_indices`` leak).
    2. The stored prompt/prefill lengths and computed-token count reflect
       the second payload.
    3. ``model_state`` and ``block_tables`` receive the second payload.
    """
    runner = mock_model_runner_with_req_states
    states = runner.req_states
    req_id = "streaming_req_0"
    free_before = len(states.free_indices)
    prompt = [1, 2, 3]

    # Step 1: first chunk -- 3 prompt tokens, all already computed.
    first_chunk = NewRequestData(
        req_id=req_id,
        prompt_token_ids=list(prompt),
        prefill_token_ids=list(prompt),
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0],),
        num_computed_tokens=3,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([first_chunk]))

    assert req_id in states.req_id_to_index
    assert len(states.free_indices) == free_before - 1

    # Step 2: streaming update. prefill_token_ids already holds the full
    # sequence (original prompt + intermediate output + new prompt tokens).
    second_chunk = NewRequestData(
        req_id=req_id,
        prompt_token_ids=list(prompt),
        prefill_token_ids=prompt + [10, 4, 5],
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0, 1],),
        # 3 original prompt tokens + 1 intermediate output token.
        num_computed_tokens=4,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: the old slot was recycled, so the free count is unchanged.
    assert len(states.free_indices) == free_before - 1

    # The request is still tracked, under exactly one index.
    assert req_id in states.req_id_to_index
    assert list(states.index_to_req_id.values()).count(req_id) == 1

    # The stored state reflects the second payload.
    slot = states.req_id_to_index[req_id]
    assert states.prompt_len.np[slot] == 3
    assert states.prefill_len.np[slot] == 6
    assert states.num_computed_prefill_tokens[slot] == 4

    # model_state and block_tables were re-registered for the new state.
    runner.model_state.add_request.assert_called_with(slot, second_chunk)
    runner.block_tables.append_block_ids.assert_called_with(
        slot, ([0, 1],), overwrite=True
    )
def test_e2e_streaming_with_multimodal_features(
    mock_model_runner_with_req_states,
):
    """Re-adding a request with mm features refreshes every registration.

    A request carrying one multimodal feature is added, then added again
    under the same id with a second feature. The test checks that:
    1. Exactly one slot stays occupied (no ``free_indices`` leak).
    2. ``encoder_cache`` sees one ``remove_request`` and one ``add_request``
       with both features.
    3. ``model_state.add_request`` is called once with the second payload.
    4. The stored prefill length matches the extended sequence.
    """
    runner = mock_model_runner_with_req_states
    states = runner.req_states
    req_id = "streaming_mm_req_0"
    free_before = len(states.free_indices)

    # Turn on the encoder cache so the multimodal path is exercised.
    runner.encoder_cache = Mock()

    # 2 text tokens, 10 placeholder tokens, 2 text tokens.
    base_tokens = [1, 2] + [0] * 10 + [3, 4]

    # Step 1: first chunk with a single audio feature.
    mm_feature_1 = Mock()
    first_chunk = NewRequestData(
        req_id=req_id,
        prompt_token_ids=list(base_tokens),
        prefill_token_ids=list(base_tokens),
        mm_features=[mm_feature_1],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0],),
        num_computed_tokens=14,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([first_chunk]))
    assert req_id in states.req_id_to_index

    # Forget the calls made so far; only the streaming update is inspected.
    runner.encoder_cache.reset_mock()
    runner.model_state.reset_mock()

    # Step 2: streaming update. The intermediate output token (100) has
    # been folded into prefill_token_ids and a new audio chunk is appended.
    mm_feature_2 = Mock()
    second_chunk = NewRequestData(
        req_id=req_id,
        prompt_token_ids=list(base_tokens),
        prefill_token_ids=base_tokens + [100] + [0] * 5 + [5],
        mm_features=[mm_feature_1, mm_feature_2],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0, 1],),
        num_computed_tokens=14,
        lora_request=None,
    )
    runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: no slot leaked and the id maps to exactly one index.
    assert len(states.free_indices) == free_before - 1
    assert list(states.index_to_req_id.values()).count(req_id) == 1

    # The encoder cache entry was dropped and registered again.
    runner.encoder_cache.remove_request.assert_called_once_with(req_id)
    runner.encoder_cache.add_request.assert_called_once_with(
        req_id, [mm_feature_1, mm_feature_2]
    )

    # model_state was re-registered with the new payload.
    slot = states.req_id_to_index[req_id]
    runner.model_state.add_request.assert_called_once_with(slot, second_chunk)

    # 14 base tokens + 1 output token + 5 placeholders + 1 text token.
    assert states.prefill_len.np[slot] == 21
vllm/model_executor/models/whisper_causal.py
View file @
85f671b8
...
...
@@ -150,7 +150,9 @@ def create_whisper_attention_backend_with_block_pooling(
new_common_attn_metadata
.
query_start_loc
*=
block_pool_size
new_common_attn_metadata
.
query_start_loc_cpu
*=
block_pool_size
new_common_attn_metadata
.
seq_lens
*=
block_pool_size
if
new_common_attn_metadata
.
_seq_lens_cpu
is
not
None
:
new_common_attn_metadata
.
_seq_lens_cpu
*=
block_pool_size
if
new_common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
new_common_attn_metadata
.
_num_computed_tokens_cpu
*=
block_pool_size
new_common_attn_metadata
.
num_actual_tokens
*=
block_pool_size
new_common_attn_metadata
.
max_query_len
*=
block_pool_size
...
...
vllm/v1/worker/gpu/attn_utils.py
View file @
85f671b8
...
...
@@ -111,6 +111,7 @@ def _reshape_kv_cache(
kv_cache_config
:
KVCacheConfig
,
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
AttentionBackend
],
cache_dtype
:
str
,
)
->
dict
[
str
,
torch
.
Tensor
]:
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
...
...
@@ -127,6 +128,7 @@ def _reshape_kv_cache(
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
,
cache_dtype
,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
...
...
@@ -155,9 +157,12 @@ def init_kv_cache(
kv_cache_config
:
KVCacheConfig
,
attn_backends
:
dict
[
str
,
AttentionBackend
],
device
:
torch
.
device
,
cache_dtype
:
str
,
)
->
dict
[
str
,
torch
.
Tensor
]:
kv_cache_raw_tensors
=
_allocate_kv_cache
(
kv_cache_config
,
device
)
kv_caches
=
_reshape_kv_cache
(
kv_cache_config
,
kv_cache_raw_tensors
,
attn_backends
)
kv_caches
=
_reshape_kv_cache
(
kv_cache_config
,
kv_cache_raw_tensors
,
attn_backends
,
cache_dtype
)
bind_kv_cache
(
kv_caches
,
forward_context
,
runner_kv_caches
)
return
kv_caches
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
85f671b8
...
...
@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_config
,
self
.
attn_backends
,
self
.
device
,
self
.
cache_config
.
cache_dtype
,
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
...
...
@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
cuda_graph_size
def _remove_request(self, req_id: str) -> bool:
    """Drop every per-request record held for ``req_id``.

    ``req_states`` is asked first; when it reports that nothing was
    removed, the remaining components are left untouched.

    Args:
        req_id: Identifier of the request to remove.

    Returns:
        True when ``req_states.remove_request`` returned a truthy value and
        the other components were cleaned up as well, otherwise False.
    """
    if self.req_states.remove_request(req_id):
        # Optional components are only present when enabled.
        encoder_cache = self.encoder_cache
        if encoder_cache is not None:
            encoder_cache.remove_request(req_id)
        logprobs_worker = self.prompt_logprobs_worker
        if logprobs_worker is not None:
            logprobs_worker.remove_request(req_id)
        self.lora_state.remove_request(req_id)
        return True
    return False
def
finish_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
finished_req_ids
=
scheduler_output
.
finished_req_ids
preempted_req_ids
=
scheduler_output
.
preempted_req_ids
if
preempted_req_ids
:
finished_req_ids
=
finished_req_ids
.
union
(
preempted_req_ids
)
for
req_id
in
finished_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_cache
.
remove_request
(
req_id
)
if
self
.
prompt_logprobs_worker
is
not
None
:
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
lora_state
.
remove_request
(
req_id
)
self
.
_remove_request
(
req_id
)
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
if
self
.
encoder_cache
is
not
None
:
...
...
@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
new_req_data
.
prompt_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
req_id
=
new_req_data
.
req_id
# Streaming input update: request already exists from a prior
# chunk. Remove old state so it can be cleanly re-added below
# with the updated prompt_token_ids and mm_features.
self
.
_remove_request
(
req_id
)
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
self
.
req_states
.
add_request
(
req_id
=
req_id
,
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
85f671b8
...
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.tasks
import
GenerationTask
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
...
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
device
=
self
.
device
,
)
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
    """List the generation tasks that ``self.model`` supports.

    The model is checked, in order, for text generation, transcription and
    realtime support, and the matching task names are collected.

    Returns:
        A tuple of task names. If the model supports transcription and its
        ``supports_transcription_only`` attribute is truthy, the result is
        exactly ``("transcription",)`` regardless of the other checks.
    """
    # Imported inside the function, as in the original implementation.
    from vllm.model_executor.models.interfaces import (
        supports_realtime,
        supports_transcription,
    )
    from vllm.model_executor.models.interfaces_base import (
        is_text_generation_model,
    )

    model = self.model
    tasks: list[GenerationTask] = []

    if is_text_generation_model(model):
        tasks.append("generate")

    if supports_transcription(model):
        # Transcription-only models expose no other task.
        if model.supports_transcription_only:
            return ("transcription",)
        tasks.append("transcription")

    if supports_realtime(model):
        tasks.append("realtime")

    return tuple(tasks)
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
if
self
.
rope_state
is
not
None
:
assert
new_req_data
.
prefill_token_ids
is
not
None
...
...
vllm/v1/worker/gpu/model_states/interface.py
View file @
85f671b8
...
...
@@ -28,8 +28,9 @@ class ModelState(ABC):
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
get_supported_generation_tasks
(
self
)
->
tuple
[
GenerationTask
,
...]:
return
(
"generate"
,)
raise
NotImplementedError
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
return
None
...
...
vllm/v1/worker/gpu/states.py
View file @
85f671b8
...
...
@@ -109,13 +109,14 @@ class RequestState:
self
.
all_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
def
remove_request
(
self
,
req_id
:
str
)
->
bool
:
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_idx
is
None
:
# Request not found.
return
return
False
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
free_indices
.
append
(
req_idx
)
return
True
def
any_prefills
(
self
,
idx_mapping_np
:
np
.
ndarray
)
->
bool
:
return
np
.
any
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment