Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55eed6b7
Unverified
Commit
55eed6b7
authored
Mar 11, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 11, 2026
Browse files
[Model Runner V2] Add WhisperModelState [6/N] (#35790)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c77181e5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
232 additions
and
20 deletions
+232
-20
.buildkite/test_areas/model_runner_v2.yaml
.buildkite/test_areas/model_runner_v2.yaml
+1
-2
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+6
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+32
-5
vllm/v1/worker/gpu/model_states/__init__.py
vllm/v1/worker/gpu/model_states/__init__.py
+5
-0
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+3
-4
vllm/v1/worker/gpu/model_states/interface.py
vllm/v1/worker/gpu/model_states/interface.py
+10
-9
vllm/v1/worker/gpu/model_states/whisper.py
vllm/v1/worker/gpu/model_states/whisper.py
+174
-0
No files found.
.buildkite/test_areas/model_runner_v2.yaml
View file @
55eed6b7
...
...
@@ -47,8 +47,7 @@ steps:
-
python3 offline_inference/audio_language.py --seed
0
-
python3 offline_inference/vision_language.py --seed
0
-
python3 offline_inference/vision_language_multi_image.py --seed
0
# TODO: uncomment once https://github.com/vllm-project/vllm/pull/35790 is merged.
#- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # TODO
-
python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed
0
# for pooling models
-
python3 pooling/embed/vision_embedding_offline.py --seed
0
# for features demo
...
...
vllm/v1/worker/gpu/attn_utils.py
View file @
55eed6b7
...
...
@@ -3,6 +3,7 @@
from
collections.abc
import
Sequence
from
typing
import
Any
,
cast
import
numpy
as
np
import
torch
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
...
@@ -180,6 +181,7 @@ def build_attn_metadata(
slot_mappings
:
torch
.
Tensor
,
kv_cache_config
:
KVCacheConfig
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
encoder_seq_lens
:
dict
[
int
,
tuple
[
torch
.
Tensor
,
np
.
ndarray
]]
|
None
=
None
,
)
->
dict
[
str
,
Any
]:
seq_lens
=
seq_lens
[:
num_reqs
]
if
dcp_local_seq_lens
is
not
None
:
...
...
@@ -204,6 +206,10 @@ def build_attn_metadata(
causal
=
True
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
if
encoder_seq_lens
and
i
in
encoder_seq_lens
:
encoder_seq_lens_gpu
,
encoder_seq_lens_cpu
=
encoder_seq_lens
[
i
]
common_attn_metadata
.
encoder_seq_lens
=
encoder_seq_lens_gpu
common_attn_metadata
.
encoder_seq_lens_cpu
=
encoder_seq_lens_cpu
for
attn_group
in
attn_groups
[
i
]:
attn_metadata_builder
=
attn_group
.
get_metadata_builder
(
0
)
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
55eed6b7
...
...
@@ -389,5 +389,6 @@ def prepare_inputs_to_capture(
slot_mappings
,
attn_groups
,
kv_cache_config
,
for_capture
=
True
,
)
return
attn_metadata
,
slot_mappings_by_layer
vllm/v1/worker/gpu/model_runner.py
View file @
55eed6b7
...
...
@@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
is_encoder_decoder
=
self
.
model_config
.
is_encoder_decoder
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
...
...
@@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
self
.
encoder_cache
=
EncoderCache
()
# Speculative decoding.
self
.
speculator
=
None
self
.
num_speculative_steps
=
0
self
.
use_aux_hidden_state_outputs
=
False
use_strict_rejection_sampling
=
False
if
self
.
speculative_config
is
not
None
:
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
use_strict_rejection_sampling
=
(
self
.
speculative_config
.
rejection_sample_method
==
"strict"
)
if
self
.
is_last_pp_rank
:
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
...
...
@@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_aux_hidden_state_outputs
=
True
if
self
.
pp_size
>
1
:
raise
ValueError
(
"EAGLE3 with pipeline parallel is not supported."
)
use_strict_rejection_sampling
=
(
self
.
speculative_config
.
rejection_sample_method
==
"strict"
)
# Draft tokens propagation - for spec-dec + struct outputs.
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
# General request states.
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
...
...
@@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
tasks
:
list
[
SupportedTask
]
=
[]
if
self
.
model_config
.
runner_type
==
"generate"
:
tasks
.
append
(
"generate"
)
tasks
.
extend
(
self
.
model_state
.
get_supported_generation_tasks
()
)
if
self
.
pooling_runner
is
not
None
:
tasks
.
extend
(
self
.
pooling_runner
.
get_supported_pooling_tasks
())
return
tuple
(
tasks
)
...
...
@@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
]
block_table_max_model_len
=
self
.
max_model_len
if
self
.
is_encoder_decoder
:
# Cross-attention block tables need to index encoder tokens
# (e.g., Whisper ~1500), which can exceed decoder max_model_len.
block_table_max_model_len
=
max
(
block_table_max_model_len
,
getattr
(
self
.
model_config
.
hf_config
,
"max_source_positions"
,
0
),
)
self
.
block_tables
=
BlockTables
(
block_sizes
=
block_sizes
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
block_table_
max_model_len
,
device
=
self
.
device
,
cp_size
=
self
.
dcp_size
,
cp_rank
=
self
.
dcp_rank
,
...
...
@@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
num_tokens_across_dp
=
None
skip_compiled
=
False
if
self
.
is_encoder_decoder
and
scheduler_output
.
scheduled_encoder_inputs
:
# Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates
# cross-attention cache with dynamic encoder outputs.
# Override batch_desc to NONE.
skip_compiled
=
True
batch_desc
=
BatchExecutionDescriptor
(
cg_mode
=
CUDAGraphMode
.
NONE
,
num_tokens
=
num_toks
,
num_reqs
=
num_reqs
,
)
if
self
.
dp_size
>
1
:
batch_desc
,
num_tokens_across_dp
=
sync_cudagraph_and_dp_padding
(
self
.
cudagraph_manager
,
...
...
@@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings_by_layer
,
skip_compiled
=
skip_compiled
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
model
(
**
model_inputs
)
...
...
vllm/v1/worker/gpu/model_states/__init__.py
View file @
55eed6b7
...
...
@@ -13,6 +13,11 @@ def init_model_state(
encoder_cache
:
EncoderCache
|
None
,
device
:
torch
.
device
,
):
if
"WhisperForConditionalGeneration"
in
vllm_config
.
model_config
.
architectures
:
from
vllm.v1.worker.gpu.model_states.whisper
import
WhisperModelState
return
WhisperModelState
(
vllm_config
,
model
,
encoder_cache
,
device
)
from
vllm.v1.worker.gpu.model_states.default
import
DefaultModelState
return
DefaultModelState
(
vllm_config
,
model
,
encoder_cache
,
device
)
vllm/v1/worker/gpu/model_states/default.py
View file @
55eed6b7
...
...
@@ -109,7 +109,7 @@ class DefaultModelState(ModelState):
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
)
->
dict
[
str
,
Any
]:
if
not
self
.
uses_mrope
:
# Common case (1D positions).
return
{}
...
...
@@ -126,9 +126,7 @@ class DefaultModelState(ModelState):
]
return
{
"positions"
:
mrope_positions
}
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
Any
]:
model_inputs
=
{}
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
[:
num_tokens
]
...
...
@@ -146,6 +144,7 @@ class DefaultModelState(ModelState):
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
for_capture
:
bool
=
False
,
)
->
dict
[
str
,
Any
]:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
...
...
vllm/v1/worker/gpu/model_states/interface.py
View file @
55eed6b7
...
...
@@ -8,6 +8,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.tasks
import
GenerationTask
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
...
...
@@ -27,13 +28,14 @@ class ModelState(ABC):
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
get_supported_generation_tasks
(
self
)
->
tuple
[
GenerationTask
,
...]:
return
(
"generate"
,)
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
r
aise
NotImplementedError
r
eturn
None
@
abstractmethod
def
apply_staged_writes
(
self
)
->
None
:
r
aise
NotImplementedError
r
eturn
None
@
abstractmethod
def
get_mm_embeddings
(
...
...
@@ -41,19 +43,17 @@ class ModelState(ABC):
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
input_batch
:
InputBatch
,
req_states
:
RequestState
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
None
:
raise
NotImplementedError
@
abstractmethod
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
)
->
dict
[
str
,
Any
]:
raise
NotImplementedError
@
abstractmethod
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
Any
]:
raise
NotImplementedError
@
abstractmethod
...
...
@@ -65,5 +65,6 @@ class ModelState(ABC):
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
for_capture
:
bool
=
False
,
)
->
dict
[
str
,
Any
]:
raise
NotImplementedError
vllm/v1/worker/gpu/model_states/whisper.py
0 → 100644
View file @
55eed6b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.kv_cache_interface
import
CrossAttentionSpec
,
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.encoder_cache
import
EncoderCache
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
class WhisperModelState(ModelState):
    """Model state for Whisper-style encoder-decoder models.

    Runs the (audio) encoder through an ``EncoderRunner``, passes the
    resulting tensors to the model via the ``encoder_outputs`` keyword, and
    supplies per-request encoder sequence lengths to the attention-metadata
    builder for every KV-cache group that contains a cross-attention layer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Cache config-derived limits and allocate encoder-side buffers.

        Args:
            vllm_config: Top-level config; the model and scheduler configs
                are read from it.
            model: The loaded model, forwarded to the ``EncoderRunner``.
            encoder_cache: Encoder cache; must not be ``None``.
            device: Device on which the encoder-length buffer is allocated.

        Raises:
            AssertionError: If ``encoder_cache`` is ``None``.
        """
        model_config = vllm_config.model_config
        scheduler_config = vllm_config.scheduler_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.model = model

        # Batch and sequence limits taken from the scheduler / model configs.
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_model_len = model_config.max_model_len
        self.device = device

        assert encoder_cache is not None
        self.encoder_cache = encoder_cache
        self.encoder_runner = EncoderRunner(
            model=model,
            max_num_tokens=self.max_num_tokens,
            hidden_size=model_config.get_inputs_embeds_size(),
            encoder_cache=encoder_cache,
            dtype=model_config.dtype,
            device=device,
        )

        # Encoder length used while capturing CUDA graphs; falls back to the
        # decoder max_model_len when the HF config has no
        # `max_source_positions` attribute.
        self.max_encoder_len = getattr(
            model_config.hf_config,
            "max_source_positions",
            self.max_model_len,
        )
        # Persistent device buffer holding one encoder length per request
        # slot; refreshed in place by `_get_encoder_seq_lens`.
        self.encoder_seq_lens_gpu = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )
        # Encoder outputs of the current step; handed over (and cleared) by
        # `prepare_inputs`.
        self.encoder_outputs: list[torch.Tensor] = []

    def get_supported_generation_tasks(self):
        """Return the generation tasks this model state supports."""
        return ("transcription",)

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> None:
        """Run the encoder for scheduled inputs and stash its outputs.

        The outputs are stored on ``self.encoder_outputs`` rather than
        returned; this method always returns ``None``.

        Args:
            scheduled_encoder_inputs: Maps request id to the encoder input
                indices scheduled for this step.
            input_batch: Current batch; its ``req_ids`` define the ordering.
            req_states: Unused here.
        """
        # Re-key the scheduled inputs in input_batch.req_ids order, dropping
        # requests that have nothing scheduled.
        ordered_inputs: dict[str, list[int]] = {
            req_id: scheduled
            for req_id in input_batch.req_ids
            if (scheduled := scheduled_encoder_inputs.get(req_id, []))
        }

        _, mm_kwargs = self.encoder_runner.prepare_mm_inputs(ordered_inputs)
        if not mm_kwargs:
            # Decode-only step: encoder K/V already live in the
            # cross-attention KV cache, so there is nothing to run.
            self.encoder_outputs = []
            return None

        # Whisper takes encoder results through `encoder_outputs` instead of
        # `inputs_embeds`. With a single modality (audio) the encoder runner
        # keeps request order, so its return value is used as-is. The results
        # are not stored in encoder_cache: cross-attention K/V are written to
        # the KV cache on the first step and later steps read from there.
        self.encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
        return None

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        """Return the extra model kwargs for this step.

        Hands over the pending encoder outputs exactly once: the stored list
        is replaced by a fresh empty list before returning.
        """
        pending, self.encoder_outputs = self.encoder_outputs, []
        return {"encoder_outputs": pending}

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Return extra model kwargs for a dummy run (no encoder outputs)."""
        return {"encoder_outputs": []}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Build attention metadata, including encoder sequence lengths.

        Args:
            input_batch: Current batch of requests.
            cudagraph_mode: With ``CUDAGraphMode.FULL`` the padded request and
                token counts are used; otherwise the unpadded ones.
            block_tables: Block tables forwarded to ``build_attn_metadata``.
            slot_mappings: Slot mappings forwarded to ``build_attn_metadata``.
            attn_groups: Attention groups, indexed by KV-cache group.
            kv_cache_config: KV-cache configuration.
            for_capture: Whether this call prepares a CUDA graph capture; if
                so, the maximum encoder length is used for every request.

        Returns:
            The result of ``build_attn_metadata``.
        """
        padded = cudagraph_mode == CUDAGraphMode.FULL
        num_reqs = (
            input_batch.num_reqs_after_padding if padded else input_batch.num_reqs
        )
        num_tokens = (
            input_batch.num_tokens_after_padding if padded else input_batch.num_tokens
        )

        # Computed first: this refreshes the persistent GPU length buffer.
        encoder_seq_lens = self._get_encoder_seq_lens(
            input_batch.req_ids, attn_groups, for_capture
        )

        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            encoder_seq_lens=encoder_seq_lens,
        )

    def _get_encoder_seq_lens(
        self,
        req_ids: list[str],
        attn_groups: list[list[AttentionGroup]],
        for_capture: bool,
    ) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
        """Compute encoder sequence lengths for cross-attention groups.

        Args:
            req_ids: Request ids of the current batch, in batch order.
            attn_groups: Attention groups, indexed by KV-cache group.
            for_capture: If true, every request gets ``self.max_encoder_len``
                instead of its actual encoder length.

        Returns:
            A dict keyed by KV-cache group index, containing only groups with
            at least one ``CrossAttentionSpec``. Each value is the pair
            (GPU lengths view, CPU lengths array), one entry per request.
        """
        num_reqs = len(req_ids)
        lens_np = np.zeros(num_reqs, dtype=np.int32)

        if for_capture:
            # During CUDA graph capture, use the max encoder length so that
            # max_seqlen_k is captured with the right value for
            # cross-attention.
            lens_np[:] = self.max_encoder_len
        else:
            # Normal execution: total embeds over each request's registered
            # multimodal features (0 when the request has none).
            features_by_req = self.encoder_cache.mm_features
            for idx, req_id in enumerate(req_ids):
                lens_np[idx] = sum(
                    feature.mm_position.get_num_embeds()
                    for feature in features_by_req.get(req_id, [])
                )

        # Refresh the persistent device buffer: live prefix gets the new
        # lengths, the unused tail is zeroed.
        gpu_buffer = self.encoder_seq_lens_gpu
        active_lens_gpu = gpu_buffer[:num_reqs]
        active_lens_gpu.copy_(torch.from_numpy(lens_np), non_blocking=True)
        gpu_buffer[num_reqs:].fill_(0)

        # Only KV-cache groups that hold a cross-attention layer get an entry.
        return {
            group_idx: (active_lens_gpu, lens_np)
            for group_idx, groups in enumerate(attn_groups)
            if any(
                isinstance(group.kv_cache_spec, CrossAttentionSpec)
                for group in groups
            )
        }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment