Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d66502e
Unverified
Commit
3d66502e
authored
Feb 26, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 26, 2026
Browse files
[Model Runner V2] Prepare attn metadata in ModelState [2/N] (#35383)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c66aa48e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
92 deletions
+110
-92
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+3
-8
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+72
-82
vllm/v1/worker/gpu/model_states.py
vllm/v1/worker/gpu/model_states.py
+31
-0
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+4
-2
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
3d66502e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -60,6 +59,8 @@ class InputBatch:
...
@@ -60,6 +59,8 @@ class InputBatch:
query_start_loc_np
:
np
.
ndarray
query_start_loc_np
:
np
.
ndarray
# [num_reqs]
# [num_reqs]
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
# [num_reqs]
dcp_local_seq_lens
:
torch
.
Tensor
|
None
# [num_tokens_after_padding]
# [num_tokens_after_padding]
input_ids
:
torch
.
Tensor
input_ids
:
torch
.
Tensor
...
@@ -68,11 +69,6 @@ class InputBatch:
...
@@ -68,11 +69,6 @@ class InputBatch:
# [num_tokens_after_padding, hidden_size]
# [num_tokens_after_padding, hidden_size]
inputs_embeds
:
torch
.
Tensor
|
None
inputs_embeds
:
torch
.
Tensor
|
None
# layer_name -> Metadata
attn_metadata
:
dict
[
str
,
Any
]
# layer_name -> slot_mapping
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
# [total_num_logits]
# [total_num_logits]
logits_indices
:
torch
.
Tensor
logits_indices
:
torch
.
Tensor
# [num_reqs + 1]
# [num_reqs + 1]
...
@@ -139,11 +135,10 @@ class InputBatch:
...
@@ -139,11 +135,10 @@ class InputBatch:
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
dcp_local_seq_lens
=
None
,
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
None
,
inputs_embeds
=
None
,
attn_metadata
=
None
,
# type: ignore
slot_mappings
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
cu_num_logits_np
=
cu_num_logits_np
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
3d66502e
...
@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
...
@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_slot_mappings_by_layer
,
build_slot_mappings_by_layer
,
get_kv_cache_spec
,
get_kv_cache_spec
,
init_attn_backend
,
init_attn_backend
,
...
@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
def
prepare_dummy_attn_metadata
(
self
,
input_batch
:
InputBatch
)
->
None
:
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
)
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
attn_metadata
=
build_attn_metadata
(
attn_groups
=
self
.
attn_groups
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
),
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
(),
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
,
)
input_batch
.
attn_metadata
=
attn_metadata
input_batch
.
slot_mappings
=
slot_mappings_by_layer
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
_dummy_run
(
def
_dummy_run
(
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
...
@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
None
,
None
return
None
,
None
assert
self
.
execute_model_state
is
not
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
_
,
input_batch
,
_
=
self
.
execute_model_state
input_batch
,
_
,
_
,
_
,
hidden_states
,
_
,
_
=
self
.
execute_model_state
self
.
execute_model_state
=
None
self
.
execute_model_state
=
None
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
...
@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
encoder_runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
self
.
encoder_runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
self
.
model_state
.
add_request
(
req_index
,
new_req_data
)
self
.
model_state
.
add_request
(
req_index
,
new_req_data
)
self
.
block_tables
.
append_block_ids
(
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
)
...
@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
)
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
# Get query_start_loc.
# Get query_start_loc.
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
query_start_loc_np
[
0
]
=
0
...
@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
max_query_len
=
num_scheduled_tokens
.
max
().
item
()
# Get prefill tokens if any.
# Get prefill tokens if any.
if
self
.
req_states
.
any_prefills
(
idx_mapping_np
):
if
self
.
req_states
.
any_prefills
(
idx_mapping_np
):
...
@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
None
if
self
.
use_dcp
:
if
self
.
use_dcp
:
# Prepare dcp local seq_lens.
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens
(
prepare_dcp_local_seq_lens
(
...
@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
dcp_rank
,
self
.
dcp_rank
,
self
.
cp_interleave
,
self
.
cp_interleave
,
)
)
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
# Some input token ids are directly read from the last sampled tokens
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
# and draft tokens. Also, get the logits indices to sample tokens from.
...
@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits
,
total_num_logits
,
)
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
],
)
# Layer name -> slot mapping.
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata
=
build_attn_metadata
(
attn_groups
=
self
.
attn_groups
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
]
return
InputBatch
(
return
InputBatch
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
...
@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
positions
=
positions
,
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
],
inputs_embeds
=
None
,
inputs_embeds
=
None
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings_by_layer
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
cu_num_logits_np
=
cu_num_logits_np
,
has_structured_output_reqs
=
scheduler_output
.
has_structured_output_requests
,
has_structured_output_reqs
=
scheduler_output
.
has_structured_output_requests
,
)
)
def
prepare_attn
(
self
,
input_batch
:
InputBatch
)
->
tuple
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
input_batch
.
idx_mapping
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
input_batch
.
idx_mapping
,
input_batch
.
query_start_loc
,
input_batch
.
positions
,
)
return
block_tables
,
slot_mappings
def
prepare_dummy_attn
(
self
,
input_batch
:
InputBatch
)
->
tuple
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
block_tables
=
self
.
block_tables
.
get_dummy_block_tables
(
input_batch
.
num_reqs
)
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
)
return
block_tables
,
slot_mappings
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
get_mm_embeddings
(
def
get_mm_embeddings
(
self
,
self
,
...
@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch
=
self
.
prepare_inputs
(
input_batch
=
self
.
prepare_inputs
(
scheduler_output
,
num_tokens_after_padding
scheduler_output
,
num_tokens_after_padding
)
)
block_tables
,
slot_mappings
=
self
.
prepare_attn
(
input_batch
)
if
self
.
lora_config
:
if
self
.
lora_config
:
# Activate LoRA adapters.
# Activate LoRA adapters.
lora_inputs
=
self
.
lora_state
.
make_lora_inputs
(
lora_inputs
=
self
.
lora_state
.
make_lora_inputs
(
...
@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
self
.
device
,
device
=
self
.
device
,
)
)
if
not
skip_attn_for_dummy_run
:
if
not
skip_attn_for_dummy_run
:
self
.
prepare_dummy_attn_metadata
(
input_batch
)
block_tables
,
slot_mappings
=
self
.
prepare_dummy_attn
(
input_batch
)
else
:
block_tables
=
None
slot_mappings
=
None
# FIXME(woosuk): Fix warmup for LoRA.
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata
=
None
slot_mappings_by_layer
=
None
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
assert
slot_mappings
is
not
None
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
slot_mappings
,
self
.
kv_cache_config
)
assert
block_tables
is
not
None
attn_metadata
=
self
.
model_state
.
prepare_attn
(
input_batch
,
block_tables
,
slot_mappings
,
self
.
attn_groups
,
self
.
kv_cache_config
,
)
model_inputs
=
{
model_inputs
=
{
"input_ids"
:
input_batch
.
input_ids
,
"input_ids"
:
input_batch
.
input_ids
,
"positions"
:
input_batch
.
positions
,
"positions"
:
input_batch
.
positions
,
...
@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
with
set_forward_context
(
with
set_forward_context
(
input_batch
.
attn_metadata
,
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
num_tokens_across_dp
=
num_tokens_across_dp
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
input_batch
.
slot_mappings
,
slot_mapping
=
slot_mappings
_by_layer
,
):
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
model
(
**
model_inputs
)
model_output
=
self
.
model
(
**
model_inputs
)
...
@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states
=
None
aux_hidden_states
=
None
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
(
input_batch
,
model_inputs
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
# Non-last PP rank: return IntermediateTensors for sending.
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
execute_model_state
=
(
None
,
None
,
input_batch
,
kv_connector_output
)
return
hidden_states
return
hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
self
.
execute_model_state
=
(
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
,
)
return
None
return
None
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
execute_model_state
is
None
:
if
self
.
execute_model_state
is
None
:
# The prior execute_model call must have failed.
# The prior execute_model call must have failed.
return
None
return
None
hidden_states
,
aux_hidden_states
,
input_batch
,
kv_connector_output
=
(
(
self
.
execute_model_state
input_batch
,
)
model_inputs
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
=
self
.
execute_model_state
self
.
execute_model_state
=
None
self
.
execute_model_state
=
None
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
...
@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
speculator
is
not
None
:
if
self
.
speculator
is
not
None
:
draft_tokens
=
self
.
speculator
.
propose
(
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
input_batch
,
attn_metadata
,
slot_mappings_by_layer
,
hidden_states
,
hidden_states
,
aux_hidden_states
,
aux_hidden_states
,
num_sampled
,
num_sampled
,
...
...
vllm/v1/worker/gpu/model_states.py
View file @
3d66502e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
class
ModelState
:
class
ModelState
:
...
@@ -72,3 +77,29 @@ class ModelState:
...
@@ -72,3 +77,29 @@ class ModelState:
return
{}
return
{}
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[:,
:
num_tokens
]
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[:,
:
num_tokens
]
return
{
"positions"
:
mrope_positions
}
return
{
"positions"
:
mrope_positions
}
def
prepare_attn
(
self
,
input_batch
:
InputBatch
,
block_tables
:
tuple
[
torch
.
Tensor
,
...],
slot_mappings
:
torch
.
Tensor
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
attn_metadata
=
build_attn_metadata
(
attn_groups
=
attn_groups
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
seq_lens
=
input_batch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
kv_cache_config
,
dcp_local_seq_lens
=
input_batch
.
dcp_local_seq_lens
,
)
return
attn_metadata
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
3d66502e
...
@@ -182,6 +182,8 @@ class EagleSpeculator:
...
@@ -182,6 +182,8 @@ class EagleSpeculator:
def
propose
(
def
propose
(
self
,
self
,
input_batch
:
InputBatch
,
input_batch
:
InputBatch
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
# [num_tokens, hidden_size]
# [num_tokens, hidden_size]
last_hidden_states
:
torch
.
Tensor
,
last_hidden_states
:
torch
.
Tensor
,
# num_layers x [num_tokens, hidden_size]
# num_layers x [num_tokens, hidden_size]
...
@@ -229,8 +231,8 @@ class EagleSpeculator:
...
@@ -229,8 +231,8 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill.
# TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
last_hidden_states
,
hidden_states
=
self
.
run_model
(
num_tokens
,
num_tokens
,
input_batch
.
attn_metadata
,
attn_metadata
,
input_batch
.
slot_mappings
,
slot_mappings
,
num_tokens_across_dp
=
None
,
# FIXME
num_tokens_across_dp
=
None
,
# FIXME
)
)
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment