Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0efd9f86
Unverified
Commit
0efd9f86
authored
Dec 11, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Dec 11, 2025
Browse files
[Core] Whisper Enable Encoder Batching (#29421)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
90d6cf92
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
87 additions
and
25 deletions
+87
-25
vllm/config/model.py
vllm/config/model.py
+5
-0
vllm/config/vllm.py
vllm/config/vllm.py
+10
-20
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+13
-4
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+53
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+6
-1
No files found.
vllm/config/model.py
View file @
0efd9f86
...
...
@@ -539,6 +539,11 @@ class ModelConfig:
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
if
self
.
is_encoder_decoder
:
self
.
mm_processor_cache_gb
=
0
logger
.
info
(
"Encoder-decoder model detected, disabling mm processor cache."
)
# Init multimodal config if needed
if
self
.
_model_info
.
supports_multimodal
:
if
(
...
...
vllm/config/vllm.py
View file @
0efd9f86
...
...
@@ -750,19 +750,9 @@ class VllmConfig:
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self
.
_set_compile_ranges
()
if
self
.
model_config
and
self
.
model_config
.
is_encoder_decoder
:
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
self
.
scheduler_config
.
max_num_encoder_input_tokens
=
(
MULTIMODAL_REGISTRY
.
get_encdec_max_encoder_len
(
self
.
model_config
)
)
logger
.
debug
(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)"
,
self
.
scheduler_config
.
max_num_encoder_input_tokens
,
)
if
(
self
.
model_config
.
architecture
==
"WhisperForConditionalGeneration"
self
.
model_config
and
self
.
model_config
.
architecture
==
"WhisperForConditionalGeneration"
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
):
logger
.
warning
(
...
...
vllm/model_executor/models/whisper.py
View file @
0efd9f86
...
...
@@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
def
forward
(
self
,
input_features
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]):
hidden_states
=
[]
input_is_batched
=
False
for
features
in
input_features
:
embeds
=
nn
.
functional
.
gelu
(
self
.
conv1
(
features
))
embeds
=
nn
.
functional
.
gelu
(
self
.
conv2
(
embeds
))
...
...
@@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
embeds
.
dtype
)
hidden_states
.
append
(
embeds
)
input_is_batched
=
embeds
.
ndim
>
2
# Input to MHA must be B x T x D
if
input_is_batched
:
# Models using WhisperEncoder may handle batching internally.
hidden_states
=
torch
.
cat
(
hidden_states
)
else
:
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
0
)
for
encoder_layer
in
self
.
layers
:
hidden_states
=
encoder_layer
(
hidden_states
)
...
...
@@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
positions
:
torch
.
Tensor
,
encoder_outputs
:
list
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
assert
len
(
encoder_outputs
)
in
(
0
,
1
)
enc_states
=
encoder_outputs
[
0
]
if
len
(
encoder_outputs
)
==
1
else
None
enc_states
=
torch
.
cat
(
encoder_outputs
,
dim
=
0
)
if
len
(
encoder_outputs
)
else
None
decoder_outputs
=
self
.
decoder
(
input_ids
=
input_ids
,
positions
=
positions
,
...
...
@@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
# Required as part of SupportsMultiModal interface.
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
return
[
self
.
model
.
get_encoder_outputs
(
audio_input
[
"input_features"
])]
# Split concatenated encoder outputs into one tensor per audio input
enc_output
=
self
.
model
.
get_encoder_outputs
(
audio_input
[
"input_features"
])
# The assumption is we can only process whole mm items (audios)
return
enc_output
.
unbind
(
dim
=
0
)
def
embed_input_ids
(
self
,
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
0efd9f86
...
...
@@ -341,3 +341,56 @@ def compute_mm_encoder_budget(
)
return
encoder_compute_budget
,
encoder_cache_size
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
# use the manager for scheduling purposes. Encoder-decoder models will eventually
# utilize the cache and this class will fold into EncoderCacheManager, as
# differences with MM models shrink.
class EncoderDecoderCacheManager(EncoderCacheManager):
    """Slot-accounting manager for encoder-decoder models.

    This class never reports a cache hit; it only tracks how many encoder
    token slots are free and which multimodal identifiers should be
    reported as freed.
    """

    def __init__(self, cache_size: int):
        """Start with every one of the `cache_size` slots free.

        NOTE(review): the parent initializer is not invoked here, so only
        the three attributes below are set by this constructor.
        """
        self.freed: list[str] = []
        self.cache_size = cache_size
        self.num_free_slots = cache_size

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Always report a cache miss."""
        return False

    def can_allocate(
        self,
        request: Request,
        input_id: int,
        encoder_compute_budget: int,
        num_tokens_to_schedule: int,
    ) -> bool:
        """Tell whether encoder input `input_id` of `request` fits right now.

        Returns False when the input's encoder token count exceeds
        `encoder_compute_budget`. Otherwise returns whether that count plus
        `num_tokens_to_schedule` fits within the free slots.
        """
        required = request.get_num_encoder_tokens(input_id)
        # The compute budget is checked against this input alone.
        if required > encoder_compute_budget:
            return False
        # The slot check also accounts for tokens already being scheduled.
        return required + num_tokens_to_schedule <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        """Reserve slots for an encoder input and queue its identifier.

        The identifier is appended to `self.freed` right away, so the next
        `get_freed_mm_hashes()` call will include it.
        """
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
        self.freed.append(request.mm_features[input_id].identifier)

    def free(self, request: Request) -> None:
        """Release the slots of every entry in `request.mm_features`."""
        num_inputs = len(request.mm_features)
        for input_id in range(num_inputs):
            self.free_encoder_input(request, input_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Return the index of every entry in `request.mm_features`."""
        num_inputs = len(request.mm_features)
        return set(range(num_inputs))

    def get_freed_mm_hashes(self) -> list[str]:
        """Return the identifiers queued so far and reset the queue."""
        drained, self.freed = self.freed, []
        return drained

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Give back the slots held by one encoder input of `request`."""
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
vllm/v1/core/sched/scheduler.py
View file @
0efd9f86
...
...
@@ -27,6 +27,7 @@ from vllm.logger import init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
EncoderDecoderCacheManager
,
compute_encoder_budget
,
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
...
...
@@ -181,7 +182,11 @@ class Scheduler(SchedulerInterface):
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
self
.
encoder_cache_manager
=
(
EncoderDecoderCacheManager
(
cache_size
=
encoder_cache_size
)
if
self
.
is_encoder_decoder
else
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
)
speculative_config
=
vllm_config
.
speculative_config
self
.
use_eagle
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment