Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85f671b8
Unverified
Commit
85f671b8
authored
Mar 20, 2026
by
Santino Ramos
Committed by
GitHub
Mar 20, 2026
Browse files
[Model Runner V2] Support Streaming Inputs (#37028)
Signed-off-by:
Santino Ramos
<
elsantinoramos@gmail.com
>
parent
8bc6b5cd
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
263 additions
and
12 deletions
+263
-12
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
.../v1/streaming_input/test_gpu_model_runner_v2_streaming.py
+207
-0
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+4
-2
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+6
-1
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+18
-6
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+23
-0
vllm/v1/worker/gpu/model_states/interface.py
vllm/v1/worker/gpu/model_states/interface.py
+2
-1
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+3
-2
No files found.
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
0 → 100644
View file @
85f671b8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
from
unittest.mock
import
Mock
import
pytest
import
torch
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
,
)
from
vllm.v1.worker.gpu.model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu.states
import
RequestState
# Module-level pytest mark: applies the ``cpu_test`` marker to every test
# collected from this file.
pytestmark = pytest.mark.cpu_test
@pytest.fixture
def mock_model_runner_with_req_states():
    """Return a spec'd ``GPUModelRunner`` mock that owns a real ``RequestState``.

    Only ``_remove_request`` and ``add_requests`` are the genuine
    ``GPUModelRunner`` implementations (bound to the mock); every other
    collaborator is either a plain ``Mock`` or ``None``.
    """
    request_state = RequestState(
        max_num_reqs=10,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        num_speculative_steps=0,
        vocab_size=32000,
        device=torch.device("cpu"),
        model_dtype=torch.float32,
        cache_draft_logits=False,
    )
    # Mock staged writes — they use Triton kernels that require GPU
    request_state.apply_staged_writes = Mock()

    mocked_runner = Mock(spec=GPUModelRunner)
    mocked_runner.req_states = request_state

    # Collaborators the tests inspect are mocks; the rest are disabled.
    for mocked_attr in ("model_state", "block_tables", "lora_state"):
        setattr(mocked_runner, mocked_attr, Mock())
    for disabled_attr in ("encoder_cache", "sampler", "prompt_logprobs_worker"):
        setattr(mocked_runner, disabled_attr, None)
    mocked_runner.is_last_pp_rank = False

    # Bind the real methods to our mock
    for method_name in ("_remove_request", "add_requests"):
        unbound = getattr(GPUModelRunner, method_name)
        setattr(mocked_runner, method_name, unbound.__get__(mocked_runner))

    return mocked_runner
def _make_scheduler_output(new_reqs):
    """Wrap ``new_reqs`` in a ``SchedulerOutput`` whose other fields are empty.

    ``new_reqs`` becomes ``scheduled_new_reqs``; the cached-request data is
    ``CachedRequestData.make_empty()`` and every remaining field is an empty
    container or zero.
    """
    output_fields = {
        "scheduled_new_reqs": new_reqs,
        "scheduled_cached_reqs": CachedRequestData.make_empty(),
        "num_scheduled_tokens": {},
        "total_num_scheduled_tokens": 0,
        "scheduled_spec_decode_tokens": {},
        "scheduled_encoder_inputs": {},
        "num_common_prefix_blocks": [],
        "finished_req_ids": set(),
        "free_encoder_mm_hashes": [],
    }
    return SchedulerOutput(**output_fields)
def test_e2e_streaming_request_update_basic_flow(
    mock_model_runner_with_req_states,
):
    """Re-adding an existing request id replaces its state in place.

    Checks performed after a second ``add_requests`` call for the same id:
    1. The number of free slots is unchanged (the old slot was released).
    2. The id maps to exactly one slot holding the updated lengths.
    3. ``model_state.add_request`` and ``block_tables.append_block_ids``
       were last called with the updated request data.
    """
    model_runner = mock_model_runner_with_req_states
    state = model_runner.req_states
    session_id = "streaming_req_0"
    free_before = len(state.free_indices)

    # Step 1: Add initial request with 3 prompt tokens, all computed
    first_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=[1, 2, 3],
        prefill_token_ids=[1, 2, 3],
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0],),
        num_computed_tokens=3,
        lora_request=None,
    )
    model_runner.add_requests(_make_scheduler_output([first_chunk]))

    assert session_id in state.req_id_to_index
    assert len(state.free_indices) == free_before - 1

    # Step 2: Create streaming update with extended prompt
    # The scheduler has already set prefill_token_ids to the full sequence
    # (original prompt + intermediate output + new prompt tokens)
    second_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=[1, 2, 3],
        prefill_token_ids=[1, 2, 3, 10, 4, 5],
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0, 1],),
        num_computed_tokens=4,  # 3 original prompt + 1 intermediate output
        lora_request=None,
    )
    model_runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: Verify no free_indices leak (old slot recycled)
    assert len(state.free_indices) == free_before - 1

    # Verify the request is still tracked with exactly one index
    assert session_id in state.req_id_to_index
    owned_slots = [
        slot for slot, owner in state.index_to_req_id.items() if owner == session_id
    ]
    assert len(owned_slots) == 1

    # Verify state was updated with new values
    slot = state.req_id_to_index[session_id]
    assert state.prompt_len.np[slot] == 3
    assert state.prefill_len.np[slot] == 6
    assert state.num_computed_prefill_tokens[slot] == 4

    # Verify model_state and block_tables were re-registered
    model_runner.model_state.add_request.assert_called_with(slot, second_chunk)
    model_runner.block_tables.append_block_ids.assert_called_with(
        slot, ([0, 1],), overwrite=True
    )
def test_e2e_streaming_with_multimodal_features(
    mock_model_runner_with_req_states,
):
    """Re-adding a request that carries mm features refreshes all its state.

    Checks performed after a second ``add_requests`` call for the same id:
    1. The number of free slots is unchanged (the old slot was released).
    2. ``encoder_cache`` saw one ``remove_request`` and one ``add_request``
       carrying the new ``mm_features`` list.
    3. ``model_state.add_request`` was called once with the updated data.
    """
    model_runner = mock_model_runner_with_req_states
    state = model_runner.req_states
    session_id = "streaming_mm_req_0"
    free_before = len(state.free_indices)

    # Enable encoder_cache for multimodal
    model_runner.encoder_cache = Mock()

    # 2 text tokens, 10 placeholder tokens, 2 text tokens.
    audio_prompt = [1, 2] + [0] * 10 + [3, 4]

    # Step 1: Add initial request with one audio feature
    first_feature = Mock()
    first_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=list(audio_prompt),
        prefill_token_ids=list(audio_prompt),
        mm_features=[first_feature],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0],),
        num_computed_tokens=14,
        lora_request=None,
    )
    model_runner.add_requests(_make_scheduler_output([first_chunk]))
    assert session_id in state.req_id_to_index

    # Reset mocks to track only the streaming update calls
    model_runner.encoder_cache.reset_mock()
    model_runner.model_state.reset_mock()

    # Step 2: Create streaming update with additional multimodal feature
    # The scheduler has folded the intermediate output (100) into
    # prefill_token_ids and added a new audio chunk
    second_feature = Mock()
    second_chunk = NewRequestData(
        req_id=session_id,
        prompt_token_ids=list(audio_prompt),
        prefill_token_ids=audio_prompt + [100] + [0] * 5 + [5],
        mm_features=[first_feature, second_feature],
        sampling_params=None,
        pooling_params=None,
        block_ids=([0, 1],),
        num_computed_tokens=14,
        lora_request=None,
    )
    model_runner.add_requests(_make_scheduler_output([second_chunk]))

    # Step 3: Verify no free_indices leak
    assert len(state.free_indices) == free_before - 1
    owned_slots = [
        slot for slot, owner in state.index_to_req_id.items() if owner == session_id
    ]
    assert len(owned_slots) == 1

    # Verify encoder_cache was cleaned up and re-registered
    encoder_cache = model_runner.encoder_cache
    encoder_cache.remove_request.assert_called_once_with(session_id)
    encoder_cache.add_request.assert_called_once_with(
        session_id, [first_feature, second_feature]
    )

    # Verify model_state was re-registered with new data
    slot = state.req_id_to_index[session_id]
    model_runner.model_state.add_request.assert_called_once_with(slot, second_chunk)

    # Verify updated prefill length
    assert state.prefill_len.np[slot] == 21
vllm/model_executor/models/whisper_causal.py
View file @
85f671b8
...
@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling(
new_common_attn_metadata
.
query_start_loc
*=
block_pool_size
new_common_attn_metadata
.
query_start_loc
*=
block_pool_size
new_common_attn_metadata
.
query_start_loc_cpu
*=
block_pool_size
new_common_attn_metadata
.
query_start_loc_cpu
*=
block_pool_size
new_common_attn_metadata
.
seq_lens
*=
block_pool_size
new_common_attn_metadata
.
seq_lens
*=
block_pool_size
new_common_attn_metadata
.
_seq_lens_cpu
*=
block_pool_size
if
new_common_attn_metadata
.
_seq_lens_cpu
is
not
None
:
new_common_attn_metadata
.
_num_computed_tokens_cpu
*=
block_pool_size
new_common_attn_metadata
.
_seq_lens_cpu
*=
block_pool_size
if
new_common_attn_metadata
.
_num_computed_tokens_cpu
is
not
None
:
new_common_attn_metadata
.
_num_computed_tokens_cpu
*=
block_pool_size
new_common_attn_metadata
.
num_actual_tokens
*=
block_pool_size
new_common_attn_metadata
.
num_actual_tokens
*=
block_pool_size
new_common_attn_metadata
.
max_query_len
*=
block_pool_size
new_common_attn_metadata
.
max_query_len
*=
block_pool_size
new_common_attn_metadata
.
max_seq_len
*=
block_pool_size
new_common_attn_metadata
.
max_seq_len
*=
block_pool_size
...
...
vllm/v1/worker/gpu/attn_utils.py
View file @
85f671b8
...
@@ -111,6 +111,7 @@ def _reshape_kv_cache(
...
@@ -111,6 +111,7 @@ def _reshape_kv_cache(
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
attn_backends
:
dict
[
str
,
AttentionBackend
],
attn_backends
:
dict
[
str
,
AttentionBackend
],
cache_dtype
:
str
,
)
->
dict
[
str
,
torch
.
Tensor
]:
)
->
dict
[
str
,
torch
.
Tensor
]:
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
for
kv_cache_group_spec
in
kv_cache_config
.
kv_cache_groups
:
...
@@ -127,6 +128,7 @@ def _reshape_kv_cache(
...
@@ -127,6 +128,7 @@ def _reshape_kv_cache(
kv_cache_spec
.
block_size
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
,
kv_cache_spec
.
head_size
,
cache_dtype
,
)
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
...
@@ -155,9 +157,12 @@ def init_kv_cache(
...
@@ -155,9 +157,12 @@ def init_kv_cache(
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
attn_backends
:
dict
[
str
,
AttentionBackend
],
attn_backends
:
dict
[
str
,
AttentionBackend
],
device
:
torch
.
device
,
device
:
torch
.
device
,
cache_dtype
:
str
,
)
->
dict
[
str
,
torch
.
Tensor
]:
)
->
dict
[
str
,
torch
.
Tensor
]:
kv_cache_raw_tensors
=
_allocate_kv_cache
(
kv_cache_config
,
device
)
kv_cache_raw_tensors
=
_allocate_kv_cache
(
kv_cache_config
,
device
)
kv_caches
=
_reshape_kv_cache
(
kv_cache_config
,
kv_cache_raw_tensors
,
attn_backends
)
kv_caches
=
_reshape_kv_cache
(
kv_cache_config
,
kv_cache_raw_tensors
,
attn_backends
,
cache_dtype
)
bind_kv_cache
(
kv_caches
,
forward_context
,
runner_kv_caches
)
bind_kv_cache
(
kv_caches
,
forward_context
,
runner_kv_caches
)
return
kv_caches
return
kv_caches
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
85f671b8
...
@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_config
,
self
.
kv_cache_config
,
self
.
attn_backends
,
self
.
attn_backends
,
self
.
device
,
self
.
device
,
self
.
cache_config
.
cache_dtype
,
)
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
...
@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
return
cuda_graph_size
return
cuda_graph_size
def _remove_request(self, req_id: str) -> bool:
    """Drop ``req_id`` from the request state and the per-request helpers.

    Returns ``False`` without touching any other component when
    ``self.req_states.remove_request`` reports a falsy result. Otherwise the
    id is also removed from ``encoder_cache`` and ``prompt_logprobs_worker``
    (each only when it is not ``None``) and from ``lora_state``, and ``True``
    is returned.
    """
    was_tracked = self.req_states.remove_request(req_id)
    if not was_tracked:
        return False

    # These two components are optional; skip whichever is not configured.
    for optional_component in (self.encoder_cache, self.prompt_logprobs_worker):
        if optional_component is not None:
            optional_component.remove_request(req_id)

    self.lora_state.remove_request(req_id)
    return True
def
finish_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
def
finish_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
finished_req_ids
=
scheduler_output
.
finished_req_ids
finished_req_ids
=
scheduler_output
.
finished_req_ids
preempted_req_ids
=
scheduler_output
.
preempted_req_ids
preempted_req_ids
=
scheduler_output
.
preempted_req_ids
if
preempted_req_ids
:
if
preempted_req_ids
:
finished_req_ids
=
finished_req_ids
.
union
(
preempted_req_ids
)
finished_req_ids
=
finished_req_ids
.
union
(
preempted_req_ids
)
for
req_id
in
finished_req_ids
:
for
req_id
in
finished_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
self
.
_remove_request
(
req_id
)
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_cache
.
remove_request
(
req_id
)
if
self
.
prompt_logprobs_worker
is
not
None
:
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
lora_state
.
remove_request
(
req_id
)
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
if
self
.
encoder_cache
is
not
None
:
if
self
.
encoder_cache
is
not
None
:
...
@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
new_req_data
.
prompt_token_ids
is
not
None
assert
new_req_data
.
prompt_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
req_id
=
new_req_data
.
req_id
req_id
=
new_req_data
.
req_id
# Streaming input update: request already exists from a prior
# chunk. Remove old state so it can be cleanly re-added below
# with the updated prompt_token_ids and mm_features.
self
.
_remove_request
(
req_id
)
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
self
.
req_states
.
add_request
(
self
.
req_states
.
add_request
(
req_id
=
req_id
,
req_id
=
req_id
,
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
85f671b8
...
@@ -7,6 +7,7 @@ import torch.nn as nn
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.tasks
import
GenerationTask
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
...
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
    """Return the generation task names that ``self.model`` supports.

    Tasks are collected in the order "generate", "transcription",
    "realtime", each gated by its predicate on the model. A model that
    supports transcription and has ``supports_transcription_only`` set
    short-circuits to ``("transcription",)`` alone.
    """
    # Imported inside the method, as in the original implementation.
    from vllm.model_executor.models.interfaces import (
        supports_realtime,
        supports_transcription,
    )
    from vllm.model_executor.models.interfaces_base import is_text_generation_model

    model = self.model
    tasks: list[GenerationTask] = []

    if is_text_generation_model(model):
        tasks.append("generate")

    if supports_transcription(model):
        if model.supports_transcription_only:
            # Transcription-only models expose no other task.
            return ("transcription",)
        tasks.append("transcription")

    if supports_realtime(model):
        tasks.append("realtime")

    return tuple(tasks)
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
if
self
.
rope_state
is
not
None
:
if
self
.
rope_state
is
not
None
:
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
...
...
vllm/v1/worker/gpu/model_states/interface.py
View file @
85f671b8
...
@@ -28,8 +28,9 @@ class ModelState(ABC):
...
@@ -28,8 +28,9 @@ class ModelState(ABC):
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
get_supported_generation_tasks
(
self
)
->
tuple
[
GenerationTask
,
...]:
def
get_supported_generation_tasks
(
self
)
->
tuple
[
GenerationTask
,
...]:
r
eturn
(
"generate"
,)
r
aise
NotImplementedError
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
return
None
return
None
...
...
vllm/v1/worker/gpu/states.py
View file @
85f671b8
...
@@ -109,13 +109,14 @@ class RequestState:
...
@@ -109,13 +109,14 @@ class RequestState:
self
.
all_token_ids
.
apply_write
()
self
.
all_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
def
remove_request
(
self
,
req_id
:
str
)
->
bool
:
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_idx
is
None
:
if
req_idx
is
None
:
# Request not found.
# Request not found.
return
return
False
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
free_indices
.
append
(
req_idx
)
self
.
free_indices
.
append
(
req_idx
)
return
True
def
any_prefills
(
self
,
idx_mapping_np
:
np
.
ndarray
)
->
bool
:
def
any_prefills
(
self
,
idx_mapping_np
:
np
.
ndarray
)
->
bool
:
return
np
.
any
(
return
np
.
any
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment