Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
10a5f4d5
Unverified
Commit
10a5f4d5
authored
Mar 09, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 09, 2026
Browse files
[Model Runner V2] Use NamedTuple for `execute_model_state` (#35930)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
fe0c085c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
34 deletions
+38
-34
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+38
-34
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
10a5f4d5
...
@@ -21,6 +21,7 @@ import functools
...
@@ -21,6 +21,7 @@ import functools
import
gc
import
gc
import
time
import
time
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
Any
,
NamedTuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
...
@@ -44,7 +45,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.outputs
import
DraftTokenIds
,
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
,
AsyncPoolingOutput
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
,
AsyncPoolingOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
...
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -213,7 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
pooling_runner
:
PoolingRunner
|
None
=
None
self
.
pooling_runner
:
PoolingRunner
|
None
=
None
# For transferring state from execute_model to subsequent sample_tokens call.
# For transferring state from execute_model to subsequent sample_tokens call.
self
.
execute_model_state
:
tupl
e
|
None
=
None
self
.
execute_model_state
:
ExecuteModelStat
e
|
None
=
None
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
...
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -375,16 +376,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
None
,
None
return
None
,
None
assert
self
.
execute_model_state
is
not
None
assert
self
.
execute_model_state
is
not
None
(
input_batch
=
self
.
execute_model_state
.
input_batch
input_batch
,
attn_metadata
=
self
.
execute_model_state
.
attn_metadata
model_inputs
,
slot_mappings_by_layer
=
self
.
execute_model_state
.
slot_mappings_by_layer
attn_metadata
,
hidden_states
=
self
.
execute_model_state
.
hidden_states
slot_mappings_by_layer
,
aux_hidden_states
=
self
.
execute_model_state
.
aux_hidden_states
hidden_states
,
num_tokens_across_dp
=
self
.
execute_model_state
.
num_tokens_across_dp
aux_hidden_states
,
kv_connector_output
,
num_tokens_across_dp
,
)
=
self
.
execute_model_state
self
.
execute_model_state
=
None
self
.
execute_model_state
=
None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
# dummy run the eagle speculator's propose to ensure DP/EP sync.
...
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -989,15 +986,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
=
None
aux_hidden_states
=
None
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
(
self
.
execute_model_state
=
ExecuteModelState
(
input_batch
,
input_batch
=
input_batch
,
model_inputs
,
attn_metadata
=
attn_metadata
,
attn_metadata
,
slot_mappings_by_layer
=
slot_mappings_by_layer
,
slot_mappings_by_layer
,
hidden_states
=
hidden_states
,
hidden_states
,
aux_hidden_states
=
aux_hidden_states
,
aux_hidden_states
,
kv_connector_output
=
kv_connector_output
,
kv_connector_output
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
,
)
)
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
...
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1016,16 +1012,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
execute_model_state
is
None
:
if
self
.
execute_model_state
is
None
:
# The prior execute_model call must have failed.
# The prior execute_model call must have failed.
return
None
return
None
(
input_batch
,
input_batch
=
self
.
execute_model_state
.
input_batch
model_inputs
,
attn_metadata
=
self
.
execute_model_state
.
attn_metadata
attn_metadata
,
slot_mappings_by_layer
=
self
.
execute_model_state
.
slot_mappings_by_layer
slot_mappings_by_layer
,
hidden_states
=
self
.
execute_model_state
.
hidden_states
hidden_states
,
aux_hidden_states
=
self
.
execute_model_state
.
aux_hidden_states
aux_hidden_states
,
kv_connector_output
=
self
.
execute_model_state
.
kv_connector_output
kv_connector_output
,
num_tokens_across_dp
=
self
.
execute_model_state
.
num_tokens_across_dp
num_tokens_across_dp
,
)
=
self
.
execute_model_state
self
.
execute_model_state
=
None
self
.
execute_model_state
=
None
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
...
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1116,9 +1110,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed.
# The prior execute_model call must have failed.
return
None
return
None
input_batch
,
_
,
_
,
_
,
hid
de
n
_state
s
,
_
,
kv_connector_output
,
_
=
(
input_batch
=
self
.
execute_mo
de
l
_state
.
input_batch
self
.
execute_model_state
hidden_states
=
self
.
execute_model_state
.
hidden_states
)
kv_connector_output
=
self
.
execute_model_state
.
kv_connector_output
self
.
execute_model_state
=
None
self
.
execute_model_state
=
None
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
...
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1164,3 +1158,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
np
.
minimum
(
np
.
minimum
(
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
computed_prefill
,
self
.
req_states
.
prefill_len
.
np
,
out
=
computed_prefill
)
)
class ExecuteModelState(NamedTuple):
    """State carried from ``execute_model`` to the subsequent ``sample_tokens`` call.

    The model runner stores one instance on ``self.execute_model_state`` after
    the forward pass; the consumer reads the fields it needs by name and then
    resets the attribute to ``None``.
    """

    # Input batch that was used for the forward pass.
    input_batch: InputBatch
    # Attention metadata; presumably keyed by layer name -- TODO confirm.
    # May be None.
    attn_metadata: dict[str, Any] | None
    # Slot-mapping tensors; the name suggests one entry per layer -- TODO
    # confirm what the string keys are. May be None.
    slot_mappings_by_layer: dict[str, torch.Tensor] | None
    # Output of the forward pass. NOTE(review): the IntermediateTensors
    # variant looks like the non-last pipeline-parallel rank case -- confirm.
    hidden_states: torch.Tensor | IntermediateTensors
    # Optional auxiliary hidden states; None when there are none.
    aux_hidden_states: list[torch.Tensor] | None
    # Result of ``kv_connector.post_forward(scheduler_output)``; may be None.
    kv_connector_output: KVConnectorOutput | None
    # Presumably per-rank token counts for data parallelism -- TODO confirm.
    # May be None.
    num_tokens_across_dp: torch.Tensor | None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment