Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
74c583bc
Unverified
Commit
74c583bc
authored
Jan 19, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Jan 19, 2026
Browse files
[Core] Whisper support `torch.compile` (#30385)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
c0a350ca
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
1 deletion
+27
-1
tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
.../openai/correctness/test_transcription_api_correctness.py
+3
-1
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+7
-0
vllm/forward_context.py
vllm/forward_context.py
+7
-0
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+2
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-0
No files found.
tests/entrypoints/openai/correctness/test_transcription_api_correctness.py
View file @
74c583bc
...
...
@@ -156,7 +156,9 @@ def test_wer_correctness(
model_name
,
dataset_repo
,
expected_wer
,
n_examples
=-
1
,
max_concurrent_request
=
None
):
# TODO refactor to use `ASRDataset`
with
RemoteOpenAIServer
(
model_name
,
[
"--enforce-eager"
])
as
remote_server
:
with
RemoteOpenAIServer
(
model_name
,
[
"--enforce-eager"
],
max_wait_seconds
=
480
)
as
remote_server
:
dataset
=
load_hf_dataset
(
dataset_repo
)
if
not
max_concurrent_request
:
...
...
vllm/compilation/decorators.py
View file @
74c583bc
...
...
@@ -25,6 +25,7 @@ from vllm.config import (
set_current_vllm_config
,
)
from
vllm.config.compilation
import
DynamicShapesType
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
...
...
@@ -388,6 +389,12 @@ def _support_torch_compile(
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
# enc-dec models where tensor shapes/types vary across invocations, preventing
# the capture of a single computational graph.
if
is_forward_context_available
()
and
get_forward_context
().
skip_compiled
:
return
self
.
forward
(
*
args
,
**
kwargs
)
# if aot_compiled_fn is set, call it with partition wrapper context.
# The partition wrapper must be active at runtime for CUDA graph
# capture to work correctly with inductor graph partitioning.
...
...
vllm/forward_context.py
View file @
74c583bc
...
...
@@ -207,6 +207,9 @@ class ForwardContext:
ubatch_slices
:
UBatchSlices
|
None
=
None
# If True, bypass the compiled model call, e.g. by using .forward() directly
skip_compiled
:
bool
=
False
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
def
__post_init__
(
self
):
...
...
@@ -240,6 +243,7 @@ def create_forward_context(
batch_descriptor
:
BatchDescriptor
|
None
=
None
,
ubatch_slices
:
UBatchSlices
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
):
return
ForwardContext
(
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
...
...
@@ -249,6 +253,7 @@ def create_forward_context(
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
batch_descriptor
=
batch_descriptor
,
ubatch_slices
=
ubatch_slices
,
skip_compiled
=
skip_compiled
,
additional_kwargs
=
additional_kwargs
or
{},
)
...
...
@@ -278,6 +283,7 @@ def set_forward_context(
cudagraph_runtime_mode
:
CUDAGraphMode
=
CUDAGraphMode
.
NONE
,
batch_descriptor
:
BatchDescriptor
|
None
=
None
,
ubatch_slices
:
UBatchSlices
|
None
=
None
,
skip_compiled
:
bool
=
False
,
):
"""A context manager that stores the current forward context,
which can include attention metadata, etc.
...
...
@@ -336,6 +342,7 @@ def set_forward_context(
batch_descriptor
,
ubatch_slices
,
additional_kwargs
,
skip_compiled
,
)
try
:
...
...
vllm/model_executor/models/whisper.py
View file @
74c583bc
...
...
@@ -19,6 +19,7 @@ from transformers import (
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.distributed
import
get_tensor_model_parallel_world_size
...
...
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
return
self
.
forward_layers
(
hidden_states
)
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
-
1
})
class
WhisperDecoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
74c583bc
...
...
@@ -3268,6 +3268,13 @@ class GPUModelRunner(
# Mark KV scales as calculated after the first forward pass
self
.
calculate_kv_scales
=
False
# Encoder-decoder models can only compile the pure decode steps where no
# encoder inputs are present. Use eager for any step that has encoder inputs
# scheduled (i.e. the first pass).
num_encoder_reqs
=
len
(
scheduler_output
.
scheduled_encoder_inputs
)
has_encoder_input
=
(
self
.
model_config
.
is_encoder_decoder
and
num_encoder_reqs
>
0
)
# Run the model.
# Use persistent buffers for CUDA graphs.
with
(
...
...
@@ -3279,6 +3286,7 @@ class GPUModelRunner(
cudagraph_runtime_mode
=
cudagraph_mode
,
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
skip_compiled
=
has_encoder_input
,
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment