Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55eed6b7
Unverified
Commit
55eed6b7
authored
Mar 11, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 11, 2026
Browse files
[Model Runner V2] Add WhisperModelState [6/N] (#35790)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c77181e5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
232 additions
and
20 deletions
+232
-20
.buildkite/test_areas/model_runner_v2.yaml
.buildkite/test_areas/model_runner_v2.yaml
+1
-2
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+6
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+32
-5
vllm/v1/worker/gpu/model_states/__init__.py
vllm/v1/worker/gpu/model_states/__init__.py
+5
-0
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+3
-4
vllm/v1/worker/gpu/model_states/interface.py
vllm/v1/worker/gpu/model_states/interface.py
+10
-9
vllm/v1/worker/gpu/model_states/whisper.py
vllm/v1/worker/gpu/model_states/whisper.py
+174
-0
No files found.
.buildkite/test_areas/model_runner_v2.yaml
View file @
55eed6b7
...
@@ -47,8 +47,7 @@ steps:
...
@@ -47,8 +47,7 @@ steps:
-
python3 offline_inference/audio_language.py --seed
0
-
python3 offline_inference/audio_language.py --seed
0
-
python3 offline_inference/vision_language.py --seed
0
-
python3 offline_inference/vision_language.py --seed
0
-
python3 offline_inference/vision_language_multi_image.py --seed
0
-
python3 offline_inference/vision_language_multi_image.py --seed
0
# TODO: uncomment once https://github.com/vllm-project/vllm/pull/35790 is merged.
-
python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed
0
#- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # TODO
# for pooling models
# for pooling models
-
python3 pooling/embed/vision_embedding_offline.py --seed
0
-
python3 pooling/embed/vision_embedding_offline.py --seed
0
# for features demo
# for features demo
...
...
vllm/v1/worker/gpu/attn_utils.py
View file @
55eed6b7
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Any
,
cast
from
typing
import
Any
,
cast
import
numpy
as
np
import
torch
import
torch
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
@@ -180,6 +181,7 @@ def build_attn_metadata(
...
@@ -180,6 +181,7 @@ def build_attn_metadata(
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
encoder_seq_lens
:
dict
[
int
,
tuple
[
torch
.
Tensor
,
np
.
ndarray
]]
|
None
=
None
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
seq_lens
=
seq_lens
[:
num_reqs
]
seq_lens
=
seq_lens
[:
num_reqs
]
if
dcp_local_seq_lens
is
not
None
:
if
dcp_local_seq_lens
is
not
None
:
...
@@ -204,6 +206,10 @@ def build_attn_metadata(
...
@@ -204,6 +206,10 @@ def build_attn_metadata(
causal
=
True
,
causal
=
True
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
)
if
encoder_seq_lens
and
i
in
encoder_seq_lens
:
encoder_seq_lens_gpu
,
encoder_seq_lens_cpu
=
encoder_seq_lens
[
i
]
common_attn_metadata
.
encoder_seq_lens
=
encoder_seq_lens_gpu
common_attn_metadata
.
encoder_seq_lens_cpu
=
encoder_seq_lens_cpu
for
attn_group
in
attn_groups
[
i
]:
for
attn_group
in
attn_groups
[
i
]:
attn_metadata_builder
=
attn_group
.
get_metadata_builder
(
0
)
attn_metadata_builder
=
attn_group
.
get_metadata_builder
(
0
)
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
55eed6b7
...
@@ -389,5 +389,6 @@ def prepare_inputs_to_capture(
...
@@ -389,5 +389,6 @@ def prepare_inputs_to_capture(
slot_mappings
,
slot_mappings
,
attn_groups
,
attn_groups
,
kv_cache_config
,
kv_cache_config
,
for_capture
=
True
,
)
)
return
attn_metadata
,
slot_mappings_by_layer
return
attn_metadata
,
slot_mappings_by_layer
vllm/v1/worker/gpu/model_runner.py
View file @
55eed6b7
...
@@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
is_encoder_decoder
=
self
.
model_config
.
is_encoder_decoder
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
...
@@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
self
.
encoder_cache
=
EncoderCache
()
self
.
encoder_cache
=
EncoderCache
()
# Speculative decoding.
self
.
speculator
=
None
self
.
speculator
=
None
self
.
num_speculative_steps
=
0
self
.
num_speculative_steps
=
0
self
.
use_aux_hidden_state_outputs
=
False
self
.
use_aux_hidden_state_outputs
=
False
use_strict_rejection_sampling
=
False
use_strict_rejection_sampling
=
False
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_steps
=
self
.
speculative_config
.
num_speculative_tokens
use_strict_rejection_sampling
=
(
self
.
speculative_config
.
rejection_sample_method
==
"strict"
)
if
self
.
is_last_pp_rank
:
if
self
.
is_last_pp_rank
:
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
self
.
speculator
=
init_speculator
(
self
.
vllm_config
,
self
.
device
)
...
@@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_aux_hidden_state_outputs
=
True
self
.
use_aux_hidden_state_outputs
=
True
if
self
.
pp_size
>
1
:
if
self
.
pp_size
>
1
:
raise
ValueError
(
"EAGLE3 with pipeline parallel is not supported."
)
raise
ValueError
(
"EAGLE3 with pipeline parallel is not supported."
)
use_strict_rejection_sampling
=
(
self
.
speculative_config
.
rejection_sample_method
==
"strict"
)
# Draft tokens propagation - for spec-dec + struct outputs.
# Draft tokens propagation - for spec-dec + struct outputs.
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
# General request states.
self
.
req_states
=
RequestState
(
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
...
@@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
tasks
:
list
[
SupportedTask
]
=
[]
tasks
:
list
[
SupportedTask
]
=
[]
if
self
.
model_config
.
runner_type
==
"generate"
:
if
self
.
model_config
.
runner_type
==
"generate"
:
tasks
.
append
(
"generate"
)
tasks
.
extend
(
self
.
model_state
.
get_supported_generation_tasks
()
)
if
self
.
pooling_runner
is
not
None
:
if
self
.
pooling_runner
is
not
None
:
tasks
.
extend
(
self
.
pooling_runner
.
get_supported_pooling_tasks
())
tasks
.
extend
(
self
.
pooling_runner
.
get_supported_pooling_tasks
())
return
tuple
(
tasks
)
return
tuple
(
tasks
)
...
@@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
]
]
block_table_max_model_len
=
self
.
max_model_len
if
self
.
is_encoder_decoder
:
# Cross-attention block tables need to index encoder tokens
# (e.g., Whisper ~1500), which can exceed decoder max_model_len.
block_table_max_model_len
=
max
(
block_table_max_model_len
,
getattr
(
self
.
model_config
.
hf_config
,
"max_source_positions"
,
0
),
)
self
.
block_tables
=
BlockTables
(
self
.
block_tables
=
BlockTables
(
block_sizes
=
block_sizes
,
block_sizes
=
block_sizes
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
block_table_
max_model_len
,
device
=
self
.
device
,
device
=
self
.
device
,
cp_size
=
self
.
dcp_size
,
cp_size
=
self
.
dcp_size
,
cp_rank
=
self
.
dcp_rank
,
cp_rank
=
self
.
dcp_rank
,
...
@@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
num_tokens_across_dp
=
None
num_tokens_across_dp
=
None
skip_compiled
=
False
if
self
.
is_encoder_decoder
and
scheduler_output
.
scheduled_encoder_inputs
:
# Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates
# cross-attention cache with dynamic encoder outputs.
# Override batch_desc to NONE.
skip_compiled
=
True
batch_desc
=
BatchExecutionDescriptor
(
cg_mode
=
CUDAGraphMode
.
NONE
,
num_tokens
=
num_toks
,
num_reqs
=
num_reqs
,
)
if
self
.
dp_size
>
1
:
if
self
.
dp_size
>
1
:
batch_desc
,
num_tokens_across_dp
=
sync_cudagraph_and_dp_padding
(
batch_desc
,
num_tokens_across_dp
=
sync_cudagraph_and_dp_padding
(
self
.
cudagraph_manager
,
self
.
cudagraph_manager
,
...
@@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings_by_layer
,
slot_mapping
=
slot_mappings_by_layer
,
skip_compiled
=
skip_compiled
,
):
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
model
(
**
model_inputs
)
model_output
=
self
.
model
(
**
model_inputs
)
...
...
vllm/v1/worker/gpu/model_states/__init__.py
View file @
55eed6b7
...
@@ -13,6 +13,11 @@ def init_model_state(
...
@@ -13,6 +13,11 @@ def init_model_state(
encoder_cache
:
EncoderCache
|
None
,
encoder_cache
:
EncoderCache
|
None
,
device
:
torch
.
device
,
device
:
torch
.
device
,
):
):
if
"WhisperForConditionalGeneration"
in
vllm_config
.
model_config
.
architectures
:
from
vllm.v1.worker.gpu.model_states.whisper
import
WhisperModelState
return
WhisperModelState
(
vllm_config
,
model
,
encoder_cache
,
device
)
from
vllm.v1.worker.gpu.model_states.default
import
DefaultModelState
from
vllm.v1.worker.gpu.model_states.default
import
DefaultModelState
return
DefaultModelState
(
vllm_config
,
model
,
encoder_cache
,
device
)
return
DefaultModelState
(
vllm_config
,
model
,
encoder_cache
,
device
)
vllm/v1/worker/gpu/model_states/default.py
View file @
55eed6b7
...
@@ -109,7 +109,7 @@ class DefaultModelState(ModelState):
...
@@ -109,7 +109,7 @@ class DefaultModelState(ModelState):
def
prepare_inputs
(
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
)
->
dict
[
str
,
Any
]:
if
not
self
.
uses_mrope
:
if
not
self
.
uses_mrope
:
# Common case (1D positions).
# Common case (1D positions).
return
{}
return
{}
...
@@ -126,9 +126,7 @@ class DefaultModelState(ModelState):
...
@@ -126,9 +126,7 @@ class DefaultModelState(ModelState):
]
]
return
{
"positions"
:
mrope_positions
}
return
{
"positions"
:
mrope_positions
}
def
prepare_dummy_inputs
(
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
Any
]:
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
model_inputs
=
{}
model_inputs
=
{}
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
[:
num_tokens
]
...
@@ -146,6 +144,7 @@ class DefaultModelState(ModelState):
...
@@ -146,6 +144,7 @@ class DefaultModelState(ModelState):
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
for_capture
:
bool
=
False
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
# Use padded sizes - padding is handled by model_runner.prepare_attn.
...
...
vllm/v1/worker/gpu/model_states/interface.py
View file @
55eed6b7
...
@@ -8,6 +8,7 @@ import torch.nn as nn
...
@@ -8,6 +8,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.tasks
import
GenerationTask
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
...
@@ -27,13 +28,14 @@ class ModelState(ABC):
...
@@ -27,13 +28,14 @@ class ModelState(ABC):
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
get_supported_generation_tasks
(
self
)
->
tuple
[
GenerationTask
,
...]:
return
(
"generate"
,)
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
r
aise
NotImplementedError
r
eturn
None
@
abstractmethod
def
apply_staged_writes
(
self
)
->
None
:
def
apply_staged_writes
(
self
)
->
None
:
r
aise
NotImplementedError
r
eturn
None
@
abstractmethod
@
abstractmethod
def
get_mm_embeddings
(
def
get_mm_embeddings
(
...
@@ -41,19 +43,17 @@ class ModelState(ABC):
...
@@ -41,19 +43,17 @@ class ModelState(ABC):
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
req_states
:
RequestState
,
req_states
:
RequestState
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
prepare_inputs
(
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
)
->
dict
[
str
,
Any
]:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
prepare_dummy_inputs
(
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
Any
]:
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
@@ -65,5 +65,6 @@ class ModelState(ABC):
...
@@ -65,5 +65,6 @@ class ModelState(ABC):
slot_mappings
:
torch
.
Tensor
,
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
for_capture
:
bool
=
False
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
raise
NotImplementedError
raise
NotImplementedError
vllm/v1/worker/gpu/model_states/whisper.py
0 → 100644
View file @
55eed6b7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.kv_cache_interface
import
CrossAttentionSpec
,
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.encoder_cache
import
EncoderCache
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
class WhisperModelState(ModelState):
    """Model state for Whisper, an encoder-decoder speech model.

    Encoder results are handed to the model through the ``encoder_outputs``
    keyword argument, and per-request encoder lengths are attached to the
    attention metadata of every KV cache group that contains a
    cross-attention layer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Cache config-derived limits and allocate the persistent buffers.

        Args:
            vllm_config: Source of the model and scheduler configuration.
            model: The loaded model; passed on to the encoder runner.
            encoder_cache: Must not be ``None`` (asserted below).
            device: Device on which the encoder-length buffer is allocated.
        """
        model_config = vllm_config.model_config
        scheduler_config = vllm_config.scheduler_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.model = model
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_model_len = model_config.max_model_len
        self.device = device

        assert encoder_cache is not None
        self.encoder_cache = encoder_cache
        self.encoder_runner = EncoderRunner(
            model=model,
            max_num_tokens=self.max_num_tokens,
            hidden_size=model_config.get_inputs_embeds_size(),
            encoder_cache=encoder_cache,
            dtype=model_config.dtype,
            device=device,
        )

        # Encoder length used during CUDA graph capture; falls back to the
        # decoder max_model_len when the HF config has no
        # `max_source_positions` attribute.
        self.max_encoder_len = getattr(
            model_config.hf_config, "max_source_positions", self.max_model_len
        )
        # Persistent int32 device buffer with one slot per request.
        self.encoder_seq_lens_gpu = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )
        # Encoder results staged by get_mm_embeddings() and drained by
        # prepare_inputs().
        self.encoder_outputs: list[torch.Tensor] = []

    def get_supported_generation_tasks(self):
        """Return the generation tasks served by this model state."""
        return ("transcription",)

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> None:
        """Run the encoder for the scheduled inputs and stage its outputs.

        The result is stored on ``self.encoder_outputs`` rather than returned;
        this method always returns ``None``.

        Args:
            scheduled_encoder_inputs: Request id -> scheduled encoder input
                indices.
            input_batch: Batch whose ``req_ids`` order is followed.
            req_states: Unused here.
        """
        # Walk input_batch.req_ids so the encoder inputs follow batch order;
        # requests with no scheduled encoder input are left out.
        ordered_inputs: dict[str, list[int]] = {
            req_id: scheduled
            for req_id in input_batch.req_ids
            if (scheduled := scheduled_encoder_inputs.get(req_id, []))
        }

        _, mm_kwargs = self.encoder_runner.prepare_mm_inputs(ordered_inputs)
        if not mm_kwargs:
            # Decode steps: encoder K/V are already in the cross-attention
            # KV cache, so there is nothing to run.
            self.encoder_outputs = []
            return None

        # Whisper reads encoder results from `encoder_outputs`, not from
        # `inputs_embeds`. With a single modality (audio), execute_mm_encoder
        # preserves request order, so its return value is used as-is.
        # Nothing is stored in encoder_cache: cross-attention K/V are written
        # to the KV cache on the first step and later steps read the cache.
        self.encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
        return None

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        """Hand over the staged encoder outputs and reset the staging list.

        Returns:
            ``{"encoder_outputs": <staged list>}``. ``self.encoder_outputs`` is
            rebound to a fresh empty list afterwards.
        """
        staged = self.encoder_outputs
        self.encoder_outputs = []
        return {"encoder_outputs": staged}

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Return model inputs with an empty ``encoder_outputs`` list."""
        return {"encoder_outputs": []}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Build attention metadata, including encoder sequence lengths.

        Args:
            input_batch: Source of request ids, query offsets and seq lens.
            cudagraph_mode: With ``CUDAGraphMode.FULL`` the padded request and
                token counts of ``input_batch`` are used; otherwise the
                unpadded ones.
            block_tables: Forwarded to ``build_attn_metadata``.
            slot_mappings: Forwarded to ``build_attn_metadata``.
            attn_groups: Attention groups indexed by KV cache group.
            kv_cache_config: Forwarded to ``build_attn_metadata``.
            for_capture: When true, encoder lengths are set to
                ``self.max_encoder_len`` instead of per-request values.

        Returns:
            Whatever ``build_attn_metadata`` returns.
        """
        use_padded_sizes = cudagraph_mode == CUDAGraphMode.FULL
        num_reqs = (
            input_batch.num_reqs_after_padding
            if use_padded_sizes
            else input_batch.num_reqs
        )
        num_tokens = (
            input_batch.num_tokens_after_padding
            if use_padded_sizes
            else input_batch.num_tokens
        )

        cross_attn_seq_lens = self._get_encoder_seq_lens(
            input_batch.req_ids, attn_groups, for_capture
        )

        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            encoder_seq_lens=cross_attn_seq_lens,
        )

    def _get_encoder_seq_lens(
        self,
        req_ids: list[str],
        attn_groups: list[list[AttentionGroup]],
        for_capture: bool,
    ) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
        """Compute encoder lengths and map them to cross-attention groups.

        Args:
            req_ids: Request ids; entry ``i`` of the result arrays belongs to
                ``req_ids[i]``.
            attn_groups: Attention groups indexed by KV cache group.
            for_capture: Use ``self.max_encoder_len`` for every request
                instead of the lengths derived from the encoder cache.

        Returns:
            KV cache group index -> ``(lengths on device, lengths as numpy)``
            for each group that holds at least one ``CrossAttentionSpec``.
            The device tensor is a view into ``self.encoder_seq_lens_gpu``.
        """
        batch_size = len(req_ids)
        lens_cpu = np.zeros(batch_size, dtype=np.int32)

        if for_capture:
            # During CUDA graph capture, use the max encoder length so that
            # max_seqlen_k is captured with the correct value for
            # cross-attention.
            lens_cpu[:] = self.max_encoder_len
        else:
            # Normal execution: sum the embed counts of each request's
            # multimodal features; requests without features stay at 0.
            for row, req_id in enumerate(req_ids):
                features = self.encoder_cache.mm_features.get(req_id, [])
                lens_cpu[row] = sum(
                    feat.mm_position.get_num_embeds() for feat in features
                )

        # Refresh the persistent device buffer and zero the unused tail.
        device_buf = self.encoder_seq_lens_gpu
        device_buf[:batch_size].copy_(torch.from_numpy(lens_cpu), non_blocking=True)
        device_buf[batch_size:].fill_(0)
        lens_gpu = device_buf[:batch_size]

        return {
            group_idx: (lens_gpu, lens_cpu)
            for group_idx, groups in enumerate(attn_groups)
            if any(
                isinstance(group.kv_cache_spec, CrossAttentionSpec)
                for group in groups
            )
        }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment