Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2cbf9656
Unverified
Commit
2cbf9656
authored
Feb 21, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 21, 2026
Browse files
[Model Runner V2] Enable CUDA graph for Eagle3 (#35040)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
30132cd1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
9 deletions
+44
-9
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+39
-7
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+5
-2
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
2cbf9656
...
...
@@ -25,10 +25,17 @@ from vllm.v1.worker.utils import AttentionGroup
class
CudaGraphManager
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
uses_mrope
:
bool
,
device
:
torch
.
device
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
uses_mrope
:
bool
,
use_aux_hidden_state_outputs
:
bool
,
device
:
torch
.
device
,
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
uses_mrope
=
uses_mrope
self
.
use_aux_hidden_state_outputs
=
use_aux_hidden_state_outputs
self
.
device
=
device
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
...
...
@@ -63,6 +70,7 @@ class CudaGraphManager:
if
self
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
:
self
.
pool
=
torch
.
cuda
.
graph_pool_handle
()
self
.
hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
aux_hidden_states
:
list
[
torch
.
Tensor
]
=
[]
def
needs_capture
(
self
)
->
bool
:
return
len
(
self
.
cudagraph_sizes
)
>
0
...
...
@@ -134,13 +142,22 @@ class CudaGraphManager:
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
):
hidden_states
=
model
(
model_output
=
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
# Allocate output buffers if not already done.
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
self
.
use_aux_hidden_state_outputs
and
not
self
.
aux_hidden_states
:
self
.
aux_hidden_states
=
[
torch
.
empty_like
(
x
)
for
x
in
aux_hidden_states
]
capture_fn
(
num_tokens
=
num_tokens
,
...
...
@@ -183,13 +200,23 @@ class CudaGraphManager:
),
torch
.
cuda
.
graph
(
graph
,
self
.
pool
),
):
hidden_states
=
model
(
model_output
=
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
# Copy outputs to the output buffers.
assert
self
.
hidden_states
is
not
None
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
if
self
.
use_aux_hidden_state_outputs
:
for
i
,
aux_hidden
in
enumerate
(
aux_hidden_states
):
self
.
aux_hidden_states
[
i
][:
num_tokens
]
=
aux_hidden
self
.
graphs
[
num_tokens
]
=
graph
def
_capture_piecewise_graph
(
...
...
@@ -298,11 +325,16 @@ class CudaGraphManager:
cudagraph_size
=
None
return
cudagraph_mode
,
cudagraph_size
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
assert
num_tokens
in
self
.
graphs
,
f
"No cudagraph for
{
num_tokens
}
tokens"
self
.
graphs
[
num_tokens
].
replay
()
assert
self
.
hidden_states
is
not
None
return
self
.
hidden_states
[:
num_tokens
]
hidden_states
=
self
.
hidden_states
[:
num_tokens
]
if
not
self
.
use_aux_hidden_state_outputs
:
return
hidden_states
return
hidden_states
,
[
x
[:
num_tokens
]
for
x
in
self
.
aux_hidden_states
]
def
get_cudagraph_sizes
(
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
2cbf9656
...
...
@@ -197,7 +197,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
uses_mrope
,
self
.
device
self
.
vllm_config
,
self
.
uses_mrope
,
self
.
use_aux_hidden_state_outputs
,
self
.
device
,
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
...
...
@@ -1044,7 +1047,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
,
input_batch
,
kv_connector_output
,
)
)
# type: ignore
return
None
@
torch
.
inference_mode
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment