Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a014a0a
Unverified
Commit
1a014a0a
authored
Feb 27, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 27, 2026
Browse files
[Model Runner V2] Move MM encoder to Model States [3/N] (#35564)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
86ac7bcf
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
135 additions
and
111 deletions
+135
-111
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+0
-6
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+0
-3
vllm/v1/worker/gpu/mm/encoder_cache.py
vllm/v1/worker/gpu/mm/encoder_cache.py
+40
-0
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+8
-32
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+34
-64
vllm/v1/worker/gpu/model_states.py
vllm/v1/worker/gpu/model_states.py
+53
-5
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+0
-1
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
1a014a0a
...
...
@@ -89,7 +89,6 @@ class CudaGraphManager:
model
:
nn
.
Module
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
...
...
@@ -116,9 +115,6 @@ class CudaGraphManager:
model_inputs
=
{
"input_ids"
:
input_buffers
.
input_ids
[:
num_tokens
],
"positions"
:
input_buffers
.
positions
[:
num_tokens
],
"inputs_embeds"
:
(
inputs_embeds
[:
num_tokens
]
if
inputs_embeds
is
not
None
else
None
),
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**
model_state
.
prepare_dummy_inputs
(
num_reqs
,
num_tokens
),
...
...
@@ -255,7 +251,6 @@ class CudaGraphManager:
model
:
nn
.
Module
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
...
...
@@ -267,7 +262,6 @@ class CudaGraphManager:
model
=
model
,
model_state
=
model_state
,
input_buffers
=
input_buffers
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
block_tables
,
attn_groups
=
attn_groups
,
kv_cache_config
=
kv_cache_config
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
1a014a0a
...
...
@@ -66,8 +66,6 @@ class InputBatch:
input_ids
:
torch
.
Tensor
# [num_tokens_after_padding]
positions
:
torch
.
Tensor
# [num_tokens_after_padding, hidden_size]
inputs_embeds
:
torch
.
Tensor
|
None
# [total_num_logits]
logits_indices
:
torch
.
Tensor
...
...
@@ -138,7 +136,6 @@ class InputBatch:
dcp_local_seq_lens
=
None
,
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
None
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
...
...
vllm/v1/worker/gpu/mm/encoder_cache.py
0 → 100644
View file @
1a014a0a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
class EncoderCache:
    """Bookkeeping for multimodal requests and cached encoder outputs.

    Two plain dictionaries are maintained:

    * ``mm_features``: request ID -> list of that request's multimodal
      feature specs.
    * ``encoder_outputs``: multimodal hash -> encoder output tensor.
    """

    def __init__(self):
        # MM hash -> encoder outputs
        self.encoder_outputs: dict[str, torch.Tensor] = {}
        # req_id -> MM features
        self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}

    # ------------------------------------------------------------------
    # Per-request feature tracking
    # ------------------------------------------------------------------
    def add_request(
        self,
        req_id: str,
        mm_features: list[MultiModalFeatureSpec],
    ) -> None:
        """Store ``mm_features`` under ``req_id``, replacing any prior entry."""
        self.mm_features[req_id] = mm_features

    def remove_request(self, req_id: str) -> None:
        """Forget the features of ``req_id``; unknown IDs are ignored."""
        self.mm_features.pop(req_id, None)

    # ------------------------------------------------------------------
    # Encoder output cache
    # ------------------------------------------------------------------
    def free_encoder_cache(self, mm_hash: str) -> None:
        """Evict the cached encoder output for ``mm_hash`` if one exists."""
        self.encoder_outputs.pop(mm_hash, None)

    def reset_encoder_cache(self) -> None:
        """Drop every cached encoder output.

        Intended to be called when model weights are updated, so that
        embeddings computed with the old weights are never reused.
        """
        # Clear in place so existing references to the dict stay valid.
        self.encoder_outputs.clear()

    def reset_mm_cache(self) -> None:
        """Clear the multimodal cache used during profiling.

        Currently a no-op placeholder.
        """
        # TODO: Implement MM budget for encoder dummy run
        return None
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
1a014a0a
...
...
@@ -4,8 +4,9 @@ import numpy as np
import
torch
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItem
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.v1.worker.gpu.mm.encoder_cache
import
EncoderCache
from
vllm.v1.worker.utils
import
sanity_check_mm_encoder_outputs
...
...
@@ -14,44 +15,19 @@ class EncoderRunner:
self
,
max_num_tokens
:
int
,
hidden_size
:
int
,
encoder_cache
:
EncoderCache
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
):
self
.
max_num_tokens
=
max_num_tokens
self
.
hidden_size
=
hidden_size
self
.
encoder_cache
=
encoder_cache
self
.
dtype
=
dtype
self
.
device
=
device
self
.
inputs_embeds
=
torch
.
zeros
(
max_num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
self
.
req_id_to_mm_features
:
dict
[
str
,
list
[
MultiModalFeatureSpec
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
torch
.
Tensor
]
=
{}
def
reset_mm_cache
(
self
)
->
None
:
"""
Clear the multi-modal cache that was used during profiling,
but no longer needed during inference.
"""
# TODO: Implement MM budget for encoder dummy run
pass
def
reset_encoder_cache
(
self
)
->
None
:
"""Clear the GPU-side encoder cache storing vision embeddings.
This should be called when model weights are updated to ensure
stale embeddings computed with old weights are not reused.
"""
self
.
encoder_cache
.
clear
()
def
add_request
(
self
,
req_id
:
str
,
mm_features
:
list
[
MultiModalFeatureSpec
]):
self
.
req_id_to_mm_features
[
req_id
]
=
mm_features
def
free_encoder_cache
(
self
,
mm_hash
:
str
)
->
None
:
self
.
encoder_cache
.
pop
(
mm_hash
,
None
)
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
req_id_to_mm_features
.
pop
(
req_id
,
None
)
def
prepare_mm_inputs
(
self
,
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]]
...
...
@@ -59,7 +35,7 @@ class EncoderRunner:
mm_hashes
:
list
[
str
]
=
[]
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
mm_features
=
self
.
req_id_to_
mm_features
[
req_id
]
mm_features
=
self
.
encoder_cache
.
mm_features
[
req_id
]
for
mm_input_id
in
encoder_input_ids
:
mm_feature
=
mm_features
[
mm_input_id
]
if
mm_feature
.
data
is
None
:
...
...
@@ -90,7 +66,7 @@ class EncoderRunner:
encoder_outputs
.
extend
(
curr_group_outputs
)
# Cache the encoder outputs by mm_hash
self
.
encoder_cache
.
update
(
zip
(
mm_hashes
,
encoder_outputs
))
self
.
encoder_cache
.
encoder_outputs
.
update
(
zip
(
mm_hashes
,
encoder_outputs
))
return
encoder_outputs
def
gather_mm_embeddings
(
...
...
@@ -122,7 +98,7 @@ class EncoderRunner:
# OPTIMIZATION: Skip decode requests.
continue
mm_features
=
self
.
req_id_to_
mm_features
[
req_id
]
mm_features
=
self
.
encoder_cache
.
mm_features
[
req_id
]
for
mm_feature
in
mm_features
:
pos_info
=
mm_feature
.
mm_position
start_pos
=
pos_info
.
offset
...
...
@@ -148,7 +124,7 @@ class EncoderRunner:
continue
mm_hash
=
mm_feature
.
identifier
encoder_output
=
self
.
encoder_cache
.
get
(
mm_hash
,
None
)
encoder_output
=
self
.
encoder_cache
.
encoder_outputs
.
get
(
mm_hash
,
None
)
assert
encoder_output
is
not
None
,
f
"Encoder cache miss for
{
mm_hash
}
."
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
1a014a0a
...
...
@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
get_kv_connector
,
)
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_
runner
import
Encoder
Runner
from
vllm.v1.worker.gpu.mm.encoder_
cache
import
Encoder
Cache
from
vllm.v1.worker.gpu.model_states
import
ModelState
from
vllm.v1.worker.gpu.pool.pooling_runner
import
PoolingRunner
from
vllm.v1.worker.gpu.pp_utils
import
pp_broadcast
,
pp_receive
...
...
@@ -127,20 +127,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
inputs_embeds_size
=
self
.
model_config
.
get_inputs_embeds_size
()
# Multimodal
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
self
.
model_config
)
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
=
EncoderRunner
(
max_num_tokens
=
self
.
max_num_tokens
,
hidden_size
=
self
.
inputs_embeds_size
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
use_async_scheduling
=
self
.
scheduler_config
.
async_scheduling
self
.
output_copy_stream
=
torch
.
cuda
.
Stream
(
self
.
device
)
...
...
@@ -162,6 +148,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
if
self
.
use_dcp
else
0
self
.
cp_interleave
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
# Multimodal
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
self
.
model_config
)
self
.
encoder_cache
=
None
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
self
.
encoder_cache
=
EncoderCache
()
self
.
speculator
=
None
self
.
num_speculative_steps
=
0
self
.
use_aux_hidden_state_outputs
=
False
...
...
@@ -272,7 +267,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model
(
self
.
speculator
)
# Initialize the components that require the model.
self
.
model_state
=
ModelState
(
self
.
vllm_config
,
self
.
model
,
self
.
device
)
self
.
model_state
=
ModelState
(
self
.
vllm_config
,
self
.
model
,
self
.
encoder_cache
,
self
.
device
)
if
self
.
is_pooling_model
:
self
.
pooling_runner
=
PoolingRunner
(
self
.
model
)
...
...
@@ -435,12 +432,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc
.
collect
()
def
reset_mm_cache
(
self
)
->
None
:
if
self
.
supports_mm_inputs
:
self
.
encoder_
runner
.
reset_mm_cache
()
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_
cache
.
reset_mm_cache
()
def
reset_encoder_cache
(
self
)
->
None
:
if
self
.
supports_mm_inputs
:
self
.
encoder_
runner
.
reset_encoder_cache
()
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_
cache
.
reset_encoder_cache
()
def
_get_num_input_tokens
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
# SP is not supported yet.
...
...
@@ -469,14 +466,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
with
self
.
maybe_setup_dummy_loras
(
self
.
lora_config
):
inputs_embeds
=
None
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
self
.
cudagraph_manager
.
capture
(
model
=
self
.
model
,
model_state
=
self
.
model_state
,
input_buffers
=
self
.
input_buffers
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
self
.
block_tables
,
attn_groups
=
self
.
attn_groups
,
kv_cache_config
=
self
.
kv_cache_config
,
...
...
@@ -511,15 +504,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
finished_req_ids
=
finished_req_ids
.
union
(
preempted_req_ids
)
for
req_id
in
finished_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
self
.
encoder_
runner
.
remove_request
(
req_id
)
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_
cache
.
remove_request
(
req_id
)
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
lora_state
.
remove_request
(
req_id
)
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
if
self
.
supports_mm_inputs
:
if
self
.
encoder_cache
is
not
None
:
for
mm_hash
in
scheduler_output
.
free_encoder_mm_hashes
:
self
.
encoder_
runner
.
free_encoder_cache
(
mm_hash
)
self
.
encoder_
cache
.
free_encoder_cache
(
mm_hash
)
def
add_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
...
...
@@ -535,8 +528,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
if
self
.
supports_mm_inputs
:
self
.
encoder_
runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_
cache
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
self
.
model_state
.
add_request
(
req_index
,
new_req_data
)
self
.
block_tables
.
append_block_ids
(
...
...
@@ -695,7 +688,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dcp_local_seq_lens
=
dcp_local_seq_lens
,
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
],
positions
=
self
.
input_buffers
.
positions
[:
num_tokens_after_padding
],
inputs_embeds
=
None
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
...
...
@@ -724,26 +716,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
block_tables
,
slot_mappings
@
torch
.
inference_mode
()
def
get_mm_embeddings
(
self
,
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
input_batch
:
InputBatch
,
)
->
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
mm_hashes
,
mm_kwargs
=
self
.
encoder_runner
.
prepare_mm_inputs
(
scheduled_encoder_inputs
)
self
.
encoder_runner
.
execute_mm_encoder
(
self
.
model
,
mm_hashes
,
mm_kwargs
)
mm_embeds
,
is_mm_embed
=
self
.
encoder_runner
.
gather_mm_embeddings
(
input_batch
.
req_ids
,
input_batch
.
num_tokens
,
input_batch
.
num_scheduled_tokens
,
input_batch
.
query_start_loc_np
,
self
.
req_states
.
prefill_len
.
np
[
input_batch
.
idx_mapping_np
],
self
.
req_states
.
num_computed_prefill_tokens
[
input_batch
.
idx_mapping_np
],
)
return
mm_embeds
,
is_mm_embed
def
sample
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -890,18 +862,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch
.
num_scheduled_tokens
,
)
self
.
_set_active_loras
(
*
lora_inputs
)
# Only first PP rank prepares multimodal embeddings.
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
)
inputs_embeds
=
self
.
encoder_runner
.
get_inputs_embeds
(
self
.
model
,
input_batch
.
input_ids
,
mm_embeds
,
is_mm_embed
)
input_batch
.
inputs_embeds
=
inputs_embeds
[
:
input_batch
.
num_tokens_after_padding
]
else
:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
...
...
@@ -934,10 +894,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_config
,
)
inputs_embeds
=
None
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
and
not
dummy_run
:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
inputs_embeds
=
self
.
model_state
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
,
self
.
req_states
,
)
model_inputs
=
{
"input_ids"
:
input_batch
.
input_ids
,
"positions"
:
input_batch
.
positions
,
"inputs_embeds"
:
input_batch
.
inputs_embeds
,
"inputs_embeds"
:
inputs_embeds
,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**
self
.
model_state
.
prepare_inputs
(
input_batch
,
self
.
req_states
),
...
...
vllm/v1/worker/gpu/model_states.py
View file @
1a014a0a
...
...
@@ -10,22 +10,43 @@ from vllm.v1.core.sched.output import NewRequestData
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.mm.encoder_cache
import
EncoderCache
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.utils
import
AttentionGroup
class
ModelState
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
device
:
torch
.
device
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
encoder_cache
:
EncoderCache
|
None
,
device
:
torch
.
device
,
):
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
model
=
model
self
.
device
=
device
self
.
supports_mm_inputs
=
encoder_cache
is
not
None
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
inputs_embeds_size
=
self
.
model_config
.
get_inputs_embeds_size
()
self
.
dtype
=
self
.
model_config
.
dtype
if
self
.
supports_mm_inputs
:
assert
encoder_cache
is
not
None
self
.
encoder_runner
=
EncoderRunner
(
max_num_tokens
=
self
.
max_num_tokens
,
hidden_size
=
self
.
inputs_embeds_size
,
encoder_cache
=
encoder_cache
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
...
...
@@ -51,6 +72,29 @@ class ModelState:
if
self
.
uses_mrope
:
self
.
mrope_state
.
apply_staged_writes
()
def get_mm_embeddings(
    self,
    scheduled_encoder_inputs: dict[str, list[int]],
    input_batch: InputBatch,
    req_states: RequestState,
) -> torch.Tensor:
    """Run the multimodal encoder and build the batch's input embeddings.

    Args:
        scheduled_encoder_inputs: Request ID -> indices of the multimodal
            inputs scheduled for encoding in this step.
        input_batch: The batch being prepared; supplies request IDs, token
            counts, query offsets, input IDs and the padded token count.
        req_states: Per-request state; its prefill lengths and computed
            prefill token counts are indexed with the batch's
            ``idx_mapping_np``.

    Returns:
        The embeddings produced by the encoder runner, sliced to
        ``input_batch.num_tokens_after_padding`` rows.
    """
    runner = self.encoder_runner

    # Encode whatever multimodal inputs were scheduled for this step. The
    # return value of `execute_mm_encoder` is intentionally not used here.
    mm_hashes, mm_kwargs = runner.prepare_mm_inputs(scheduled_encoder_inputs)
    runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)

    # Select the per-request prefill bookkeeping for the requests in this
    # batch.
    batch_indices = input_batch.idx_mapping_np
    prefill_lens = req_states.prefill_len.np[batch_indices]
    computed_prefill_tokens = req_states.num_computed_prefill_tokens[batch_indices]

    mm_embeds, is_mm_embed = runner.gather_mm_embeddings(
        input_batch.req_ids,
        input_batch.num_tokens,
        input_batch.num_scheduled_tokens,
        input_batch.query_start_loc_np,
        prefill_lens,
        computed_prefill_tokens,
    )

    merged_embeds = runner.get_inputs_embeds(
        self.model,
        input_batch.input_ids,
        mm_embeds,
        is_mm_embed,
    )
    return merged_embeds[: input_batch.num_tokens_after_padding]
def
prepare_inputs
(
self
,
input_batch
:
InputBatch
,
req_states
:
RequestState
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
...
...
@@ -73,10 +117,14 @@ class ModelState:
def
prepare_dummy_inputs
(
self
,
num_reqs
:
int
,
num_tokens
:
int
)
->
dict
[
str
,
torch
.
Tensor
|
None
]:
if
not
self
.
uses_mrope
:
return
{}
model_inputs
=
{}
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
[:
num_tokens
]
model_inputs
[
"inputs_embeds"
]
=
inputs_embeds
if
self
.
uses_mrope
:
mrope_positions
=
self
.
mrope_state
.
mrope_positions
[:,
:
num_tokens
]
return
{
"positions"
:
mrope_positions
}
model_inputs
[
"positions"
]
=
mrope_positions
return
model_inputs
def
prepare_attn
(
self
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
1a014a0a
...
...
@@ -44,7 +44,6 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
dtype
=
vllm_config
.
model_config
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment