Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c66aa48e
Unverified
Commit
c66aa48e
authored
Feb 26, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 26, 2026
Browse files
[Model Runner V2] Add model states [1/N] (#35350)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
b6d5a172
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
119 additions
and
101 deletions
+119
-101
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+23
-36
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+0
-3
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+22
-62
vllm/v1/worker/gpu/model_states.py
vllm/v1/worker/gpu/model_states.py
+74
-0
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
c66aa48e
...
...
@@ -22,6 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import (
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.gpu.model_states
import
ModelState
from
vllm.v1.worker.utils
import
AttentionGroup
...
...
@@ -29,13 +30,11 @@ class CudaGraphManager:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
uses_mrope
:
bool
,
use_aux_hidden_state_outputs
:
bool
,
device
:
torch
.
device
,
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
uses_mrope
=
uses_mrope
self
.
use_aux_hidden_state_outputs
=
use_aux_hidden_state_outputs
self
.
device
=
device
...
...
@@ -88,8 +87,8 @@ class CudaGraphManager:
num_tokens
:
int
,
capture_cg_mode
:
CUDAGraphMode
,
model
:
nn
.
Module
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
...
...
@@ -113,13 +112,18 @@ class CudaGraphManager:
)
else
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
if
self
.
uses_mrope
:
assert
mrope_positions
is
not
None
positions
=
mrope_positions
[:,
:
num_tokens
]
if
inputs_embeds
is
not
None
:
inputs_embeds
=
inputs_embeds
[:
num_tokens
]
model_inputs
=
{
"input_ids"
:
input_buffers
.
input_ids
[:
num_tokens
],
"positions"
:
input_buffers
.
positions
[:
num_tokens
],
"inputs_embeds"
:
(
inputs_embeds
[:
num_tokens
]
if
inputs_embeds
is
not
None
else
None
),
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**
model_state
.
prepare_dummy_inputs
(
num_reqs
,
num_tokens
),
}
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
...
...
@@ -143,11 +147,7 @@ class CudaGraphManager:
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
):
model_output
=
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
model_output
=
model
(
**
model_inputs
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
...
...
@@ -164,9 +164,7 @@ class CudaGraphManager:
num_tokens
=
num_tokens
,
num_reqs
=
num_reqs
,
model
=
model
,
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
model_inputs
=
model_inputs
,
num_tokens_across_dp
=
num_tokens_across_dp
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings
,
...
...
@@ -178,9 +176,7 @@ class CudaGraphManager:
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
,
model_inputs
:
dict
[
str
,
torch
.
Tensor
|
None
],
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
...
...
@@ -206,11 +202,8 @@ class CudaGraphManager:
),
torch
.
cuda
.
graph
(
graph
,
self
.
pool
),
):
model_output
=
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
model_output
=
model
(
**
model_inputs
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
...
...
@@ -235,9 +228,7 @@ class CudaGraphManager:
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
,
model_inputs
:
dict
[
str
,
torch
.
Tensor
|
None
],
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
...
...
@@ -256,18 +247,14 @@ class CudaGraphManager:
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings
,
):
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
model
(
**
model_inputs
)
@
torch
.
inference_mode
()
def
capture
(
self
,
model
:
nn
.
Module
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
...
...
@@ -278,8 +265,8 @@ class CudaGraphManager:
device
=
self
.
device
,
capture_fn
=
self
.
capture_graph
,
model
=
model
,
model_state
=
model_state
,
input_buffers
=
input_buffers
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
block_tables
,
attn_groups
=
attn_groups
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
c66aa48e
...
...
@@ -65,8 +65,6 @@ class InputBatch:
input_ids
:
torch
.
Tensor
# [num_tokens_after_padding]
positions
:
torch
.
Tensor
# [3, num_tokens_after_padding]
mrope_positions
:
torch
.
Tensor
|
None
# [num_tokens_after_padding, hidden_size]
inputs_embeds
:
torch
.
Tensor
|
None
...
...
@@ -143,7 +141,6 @@ class InputBatch:
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
positions
=
positions
,
mrope_positions
=
None
,
inputs_embeds
=
None
,
attn_metadata
=
None
,
# type: ignore
slot_mappings
=
None
,
# type: ignore
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
c66aa48e
...
...
@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
)
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.m
m.mrope_util
s
import
M
Rope
State
from
vllm.v1.worker.gpu.m
odel_state
s
import
M
odel
State
from
vllm.v1.worker.gpu.pp_utils
import
pp_broadcast
,
pp_receive
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
...
...
@@ -140,14 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
self
.
mrope_states
=
MRopeState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
)
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
...
...
@@ -212,7 +204,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
uses_mrope
,
self
.
use_aux_hidden_state_outputs
,
self
.
device
,
)
...
...
@@ -271,6 +262,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculator
is
not
None
:
prepare_communication_buffer_for_model
(
self
.
speculator
)
# Initialize the components that require the model.
self
.
model_state
=
ModelState
(
self
.
vllm_config
,
self
.
model
,
self
.
device
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
...
...
@@ -481,16 +475,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
with
self
.
maybe_setup_dummy_loras
(
self
.
lora_config
):
mrope_positions
=
None
if
self
.
uses_mrope
:
mrope_positions
=
self
.
mrope_states
.
mrope_positions
inputs_embeds
=
None
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
self
.
cudagraph_manager
.
capture
(
model
=
self
.
model
,
model_state
=
self
.
model_state
,
input_buffers
=
self
.
input_buffers
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
self
.
block_tables
,
attn_groups
=
self
.
attn_groups
,
...
...
@@ -554,14 +545,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
# Pre-compute M-RoPE positions for prefill.
if
self
.
uses_mrope
:
self
.
mrope_states
.
init_prefill_mrope_positions
(
req_index
,
self
.
model
,
# type: ignore
new_req_data
.
prefill_token_ids
,
mm_features
=
new_req_data
.
mm_features
,
)
self
.
model_state
.
add_request
(
req_index
,
new_req_data
)
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
...
...
@@ -577,8 +561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
()
if
self
.
uses_mrope
:
self
.
mrope_states
.
apply_staged_writes
()
self
.
model_state
.
apply_staged_writes
()
def
update_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
# Add new blocks for the existing requests.
...
...
@@ -692,15 +675,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
# Prepare M-RoPE positions.
if
self
.
uses_mrope
:
self
.
mrope_states
.
prepare_mrope_positions
(
idx_mapping
,
query_start_loc
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
)
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices
=
combine_sampled_and_draft_tokens
(
...
...
@@ -744,10 +718,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
mrope_positions
=
None
if
self
.
uses_mrope
:
mrope_positions
=
self
.
mrope_states
.
mrope_positions
mrope_positions
=
mrope_positions
[:,
:
num_tokens_after_padding
]
return
InputBatch
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
...
...
@@ -764,7 +734,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
positions
=
positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
None
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings_by_layer
,
...
...
@@ -959,14 +928,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_buffers
=
self
.
input_buffers
,
device
=
self
.
device
,
)
if
self
.
uses_mrope
:
input_batch
.
mrope_positions
=
self
.
mrope_states
.
mrope_positions
[
:,
:
num_tokens_after_padding
]
if
not
skip_attn_for_dummy_run
:
self
.
prepare_dummy_attn_metadata
(
input_batch
)
# FIXME(woosuk): Fix warmup for LoRA.
model_inputs
=
{
"input_ids"
:
input_batch
.
input_ids
,
"positions"
:
input_batch
.
positions
,
"inputs_embeds"
:
input_batch
.
inputs_embeds
,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**
self
.
model_state
.
prepare_inputs
(
input_batch
,
self
.
req_states
),
}
if
not
self
.
is_first_pp_rank
:
# Update for non-first PP ranks.
model_inputs
[
"input_ids"
]
=
None
model_inputs
[
"inputs_embeds"
]
=
None
model_inputs
[
"intermediate_tensors"
]
=
intermediate_tensors
# Run model.
if
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
:
# Use explicit cudagraph replay for FULL mode.
...
...
@@ -983,20 +962,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
=
None
else
:
# For piecewise and eager mode, just call model().
positions
=
input_batch
.
positions
if
self
.
uses_mrope
:
assert
input_batch
.
mrope_positions
is
not
None
positions
=
input_batch
.
mrope_positions
if
self
.
is_first_pp_rank
:
input_ids
=
input_batch
.
input_ids
inputs_embeds
=
input_batch
.
inputs_embeds
assert
intermediate_tensors
is
None
else
:
input_ids
=
None
inputs_embeds
=
None
assert
intermediate_tensors
is
not
None
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
input_batch
.
num_tokens_after_padding
,
has_lora
=
self
.
lora_config
is
not
None
,
...
...
@@ -1012,12 +977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping
=
input_batch
.
slot_mappings
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
)
model_output
=
self
.
model
(
**
model_inputs
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
...
...
vllm/v1/worker/gpu/model_states.py
0 → 100644
View file @
c66aa48e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.states
import
RequestState
class
ModelState
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
device
:
torch
.
device
):
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
model
=
model
self
.
device
=
device
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
self
.
mrope_state
=
MRopeState
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
)
def
add_request
(
self
,
req_index
:
int
,
new_req_data
:
NewRequestData
)
->
None
:
if
self
.
uses_mrope
:
# Pre-compute M-RoPE positions for prefill.
assert
new_req_data
.
prefill_token_ids
is
not
None
self
.
mrope_state
.
init_prefill_mrope_positions
(
req_index
,
self
.
model
,
# type: ignore
new_req_data
.
prefill_token_ids
,
mm_features
=
new_req_data
.
mm_features
,
)
def
apply_staged_writes
(
self
)
->
None
:
if
self
.
uses_mrope
:
self
.
mrope_state
.
apply_staged_writes
()
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
if
not
self
.
uses_mrope
:
# Common case (1D positions).
return
{}
# Prepare M-RoPE positions.
self
.
mrope_state
.
prepare_mrope_positions
(
input_batch
.
idx_mapping
,
input_batch
.
query_start_loc
,
req_states
.
prefill_len
.
gpu
,
req_states
.
num_computed_tokens
.
gpu
,
)
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[
:,
:
input_batch
.
num_tokens_after_padding
]
return
{
"positions"
:
mrope_positions
}
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
if
not
self
.
uses_mrope
:
return
{}
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[:,
:
num_tokens
]
return
{
"positions"
:
mrope_positions
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment