Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d66502e
Unverified
Commit
3d66502e
authored
Feb 26, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 26, 2026
Browse files
[Model Runner V2] Prepare attn metadata in ModelState [2/N] (#35383)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c66aa48e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
92 deletions
+110
-92
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+3
-8
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+72
-82
vllm/v1/worker/gpu/model_states.py
vllm/v1/worker/gpu/model_states.py
+31
-0
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+4
-2
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
3d66502e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
import
numpy
as
np
import
torch
...
...
@@ -60,6 +59,8 @@ class InputBatch:
query_start_loc_np
:
np
.
ndarray
# [num_reqs]
seq_lens
:
torch
.
Tensor
# [num_reqs]
dcp_local_seq_lens
:
torch
.
Tensor
|
None
# [num_tokens_after_padding]
input_ids
:
torch
.
Tensor
...
...
@@ -68,11 +69,6 @@ class InputBatch:
# [num_tokens_after_padding, hidden_size]
inputs_embeds
:
torch
.
Tensor
|
None
# layer_name -> Metadata
attn_metadata
:
dict
[
str
,
Any
]
# layer_name -> slot_mapping
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
# [total_num_logits]
logits_indices
:
torch
.
Tensor
# [num_reqs + 1]
...
...
@@ -139,11 +135,10 @@ class InputBatch:
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
dcp_local_seq_lens
=
None
,
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
None
,
attn_metadata
=
None
,
# type: ignore
slot_mappings
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
3d66502e
...
...
@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_slot_mappings_by_layer
,
get_kv_cache_spec
,
init_attn_backend
,
...
...
@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
def
prepare_dummy_attn_metadata
(
self
,
input_batch
:
InputBatch
)
->
None
:
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
)
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
attn_metadata
=
build_attn_metadata
(
attn_groups
=
self
.
attn_groups
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
),
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
(),
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
,
)
input_batch
.
attn_metadata
=
attn_metadata
input_batch
.
slot_mappings
=
slot_mappings_by_layer
@
torch
.
inference_mode
()
def
_dummy_run
(
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
...
...
@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
None
,
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
_
,
input_batch
,
_
=
self
.
execute_model_state
input_batch
,
_
,
_
,
_
,
hidden_states
,
_
,
_
=
self
.
execute_model_state
self
.
execute_model_state
=
None
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
...
...
@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
encoder_runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
self
.
model_state
.
add_request
(
req_index
,
new_req_data
)
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
...
...
@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
# Get query_start_loc.
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
...
...
@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
max_query_len
=
num_scheduled_tokens
.
max
().
item
()
# Get prefill tokens if any.
if
self
.
req_states
.
any_prefills
(
idx_mapping_np
):
...
...
@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
None
if
self
.
use_dcp
:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens
(
...
...
@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
dcp_rank
,
self
.
cp_interleave
,
)
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
...
...
@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits
,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
],
)
# Layer name -> slot mapping.
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata
=
build_attn_metadata
(
attn_groups
=
self
.
attn_groups
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
return
InputBatch
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
...
...
@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
positions
=
positions
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
],
inputs_embeds
=
None
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings_by_layer
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
has_structured_output_reqs
=
scheduler_output
.
has_structured_output_requests
,
)
def
prepare_attn
(
self
,
input_batch
:
InputBatch
)
->
tuple
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
input_batch
.
idx_mapping
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
input_batch
.
idx_mapping
,
input_batch
.
query_start_loc
,
input_batch
.
positions
,
)
return
block_tables
,
slot_mappings
def
prepare_dummy_attn
(
self
,
input_batch
:
InputBatch
)
->
tuple
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
)
return
block_tables
,
slot_mappings
@
torch
.
inference_mode
()
def
get_mm_embeddings
(
self
,
...
...
@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch
=
self
.
prepare_inputs
(
scheduler_output
,
num_tokens_after_padding
)
block_tables
,
slot_mappings
=
self
.
prepare_attn
(
input_batch
)
if
self
.
lora_config
:
# Activate LoRA adapters.
lora_inputs
=
self
.
lora_state
.
make_lora_inputs
(
...
...
@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
)
if
not
skip_attn_for_dummy_run
:
self
.
prepare_dummy_attn_metadata
(
input_batch
)
block_tables
,
slot_mappings
=
self
.
prepare_dummy_attn
(
input_batch
)
else
:
block_tables
=
None
slot_mappings
=
None
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata
=
None
slot_mappings_by_layer
=
None
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
assert
slot_mappings
is
not
None
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
assert
block_tables
is
not
None
attn_metadata
=
self
.
model_state
.
prepare_attn
(
input_batch
,
block_tables
,
slot_mappings
,
self
.
attn_groups
,
self
.
kv_cache_config
,
)
model_inputs
=
{
"input_ids"
:
input_batch
.
input_ids
,
"positions"
:
input_batch
.
positions
,
...
...
@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
with
set_forward_context
(
input_batch
.
attn_metadata
,
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
input_batch
.
slot_mappings
,
slot_mapping
=
slot_mappings
_by_layer
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
model
(
**
model_inputs
)
...
...
@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
=
None
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
(
input_batch
,
model_inputs
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
if
not
self
.
is_last_pp_rank
:
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
execute_model_state
=
(
None
,
None
,
input_batch
,
kv_connector_output
)
return
hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
self
.
execute_model_state
=
(
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
,
)
return
None
@
torch
.
inference_mode
()
...
...
@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
execute_model_state
is
None
:
# The prior execute_model call must have failed.
return
None
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
=
(
self
.
execute_model_state
)
(
input_batch
,
model_inputs
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
=
self
.
execute_model_state
self
.
execute_model_state
=
None
if
not
self
.
is_last_pp_rank
:
...
...
@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculator
is
not
None
:
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
aux_hidden_states
,
num_sampled
,
...
...
vllm/v1/worker/gpu/model_states.py
View file @
3d66502e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
class
ModelState
:
...
...
@@ -72,3 +77,29 @@ class ModelState:
return
{}
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[:,
:
num_tokens
]
return
{
"positions"
:
mrope_positions
}
def
prepare_attn
(
self
,
input_batch
:
InputBatch
,
block_tables
:
tuple
[
torch
.
Tensor
,
...],
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
attn_metadata
=
build_attn_metadata
(
attn_groups
=
attn_groups
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
kv_cache_config
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
)
return
attn_metadata
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
3d66502e
...
...
@@ -182,6 +182,8 @@ class EagleSpeculator:
def
propose
(
self
,
input_batch
:
InputBatch
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
# [num_tokens, hidden_size]
last_hidden_states
:
torch
.
Tensor
,
# num_layers x [num_tokens, hidden_size]
...
...
@@ -229,8 +231,8 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_tokens
,
input_batch
.
attn_metadata
,
input_batch
.
slot_mappings
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
=
None
,
# FIXME
)
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment