Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb1848cd
Unverified
Commit
bb1848cd
authored
Jan 18, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 18, 2026
Browse files
[Model Runner V2] Support VLM (#32546)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
6101a26d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
263 additions
and
15 deletions
+263
-15
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+7
-0
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+3
-3
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+184
-0
vllm/v1/worker/gpu/mm/mrope_utils.py
vllm/v1/worker/gpu/mm/mrope_utils.py
+3
-5
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+66
-4
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+0
-3
No files found.
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
bb1848cd
...
@@ -76,6 +76,7 @@ class CudaGraphManager:
...
@@ -76,6 +76,7 @@ class CudaGraphManager:
model
:
nn
.
Module
,
model
:
nn
.
Module
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
mrope_positions
:
torch
.
Tensor
|
None
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
...
@@ -86,6 +87,8 @@ class CudaGraphManager:
...
@@ -86,6 +87,8 @@ class CudaGraphManager:
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
assert
mrope_positions
is
not
None
assert
mrope_positions
is
not
None
positions
=
mrope_positions
[:,
:
num_tokens
]
positions
=
mrope_positions
[:,
:
num_tokens
]
if
inputs_embeds
is
not
None
:
inputs_embeds
=
inputs_embeds
[:
num_tokens
]
attn_metadata
=
prepare_inputs_to_capture
(
attn_metadata
=
prepare_inputs_to_capture
(
num_reqs
,
num_reqs
,
num_tokens
,
num_tokens
,
...
@@ -108,6 +111,7 @@ class CudaGraphManager:
...
@@ -108,6 +111,7 @@ class CudaGraphManager:
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
)
if
self
.
hidden_states
is
None
:
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
...
@@ -128,6 +132,7 @@ class CudaGraphManager:
...
@@ -128,6 +132,7 @@ class CudaGraphManager:
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
)
)
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
self
.
graphs
[
num_tokens
]
=
graph
self
.
graphs
[
num_tokens
]
=
graph
...
@@ -138,6 +143,7 @@ class CudaGraphManager:
...
@@ -138,6 +143,7 @@ class CudaGraphManager:
model
:
nn
.
Module
,
model
:
nn
.
Module
,
input_buffers
:
InputBuffers
,
input_buffers
:
InputBuffers
,
mrope_positions
:
torch
.
Tensor
|
None
,
mrope_positions
:
torch
.
Tensor
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
,
block_tables
:
BlockTables
,
block_tables
:
BlockTables
,
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
...
@@ -149,6 +155,7 @@ class CudaGraphManager:
...
@@ -149,6 +155,7 @@ class CudaGraphManager:
model
=
model
,
model
=
model
,
input_buffers
=
input_buffers
,
input_buffers
=
input_buffers
,
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
attn_metadata_builders
=
attn_metadata_builders
,
attn_metadata_builders
=
attn_metadata_builders
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
bb1848cd
...
@@ -15,9 +15,6 @@ class InputBuffers:
...
@@ -15,9 +15,6 @@ class InputBuffers:
self
,
self
,
max_num_reqs
:
int
,
max_num_reqs
:
int
,
max_num_tokens
:
int
,
max_num_tokens
:
int
,
inputs_embeds_size
:
int
,
vocab_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
device
:
torch
.
device
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
...
@@ -64,6 +61,8 @@ class InputBatch:
...
@@ -64,6 +61,8 @@ class InputBatch:
positions
:
torch
.
Tensor
positions
:
torch
.
Tensor
# [3, num_tokens_after_padding]
# [3, num_tokens_after_padding]
mrope_positions
:
torch
.
Tensor
|
None
mrope_positions
:
torch
.
Tensor
|
None
# [num_tokens_after_padding, hidden_size]
inputs_embeds
:
torch
.
Tensor
|
None
# layer_name -> Metadata
# layer_name -> Metadata
attn_metadata
:
dict
[
str
,
Any
]
attn_metadata
:
dict
[
str
,
Any
]
...
@@ -132,6 +131,7 @@ class InputBatch:
...
@@ -132,6 +131,7 @@ class InputBatch:
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
mrope_positions
=
None
,
mrope_positions
=
None
,
inputs_embeds
=
None
,
attn_metadata
=
None
,
# type: ignore
attn_metadata
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
0 → 100644
View file @
bb1848cd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItem
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.utils
import
sanity_check_mm_encoder_outputs
class EncoderRunner:
    """Runs the multimodal encoder and tracks its per-request state.

    The runner keeps three pieces of state:

    * ``req_id_to_mm_features``: the multimodal features registered for each
      request id via ``add_request``.
    * ``encoder_cache``: encoder outputs keyed by multimodal hash
      (``mm_feature.identifier``), filled by ``execute_mm_encoder``.
    * ``inputs_embeds``: a pre-allocated ``(max_num_tokens, hidden_size)``
      buffer that ``get_inputs_embeds`` writes into and returns, so the same
      storage is reused across steps (needed for CUDA graphs).
    """

    def __init__(
        self,
        max_num_tokens: int,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """Allocate the embedding buffer and initialise empty bookkeeping.

        Args:
            max_num_tokens: Number of rows in the ``inputs_embeds`` buffer and
                the size handed to the ``is_mm_embed`` staging pool.
            hidden_size: Number of columns in the ``inputs_embeds`` buffer.
            dtype: Dtype of the ``inputs_embeds`` buffer.
            device: Device holding ``inputs_embeds``; also the device passed
                to ``group_mm_kwargs_by_modality``.
        """
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

        self.inputs_embeds = torch.zeros(
            max_num_tokens,
            hidden_size,
            dtype=dtype,
            device=device,
        )
        self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
        self.encoder_cache: dict[str, torch.Tensor] = {}

        # Staging pool used to move the CPU-built boolean mask to the GPU in
        # gather_mm_embeddings().
        self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)

    def add_request(
        self, req_id: str, mm_features: list[MultiModalFeatureSpec]
    ) -> None:
        """Register (or overwrite) the multimodal features of ``req_id``."""
        self.req_id_to_mm_features[req_id] = mm_features

    def free_encoder_cache(self, mm_hash: str) -> None:
        """Drop the cached encoder output for ``mm_hash``; no-op if absent."""
        self.encoder_cache.pop(mm_hash, None)

    def remove_request(self, req_id: str) -> None:
        """Forget the features of ``req_id``; no-op if it is not registered."""
        self.req_id_to_mm_features.pop(req_id, None)

    def prepare_mm_inputs(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
    ) -> tuple[list[str], list[MultiModalKwargsItem]]:
        """Collect the hashes and encoder kwargs of the scheduled inputs.

        Args:
            scheduled_encoder_inputs: Maps a request id to the indices (into
                that request's registered feature list) scheduled for encoding.

        Returns:
            Two parallel lists ``(mm_hashes, mm_kwargs)``. Features whose
            ``data`` is ``None`` are skipped in both lists.

        Raises:
            KeyError: If a request id has not been registered.
        """
        mm_hashes: list[str] = []
        mm_kwargs: list[MultiModalKwargsItem] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            mm_features = self.req_id_to_mm_features[req_id]
            for mm_input_id in encoder_input_ids:
                mm_feature = mm_features[mm_input_id]
                if mm_feature.data is None:
                    continue
                mm_hashes.append(mm_feature.identifier)
                mm_kwargs.append(mm_feature.data)
        return mm_hashes, mm_kwargs

    @torch.inference_mode()
    def execute_mm_encoder(
        self,
        model: SupportsMultiModal,
        mm_hashes: list[str],
        mm_kwargs: list[MultiModalKwargsItem],
    ) -> list[torch.Tensor]:
        """Run the encoder on ``mm_kwargs`` and cache each output by hash.

        Args:
            model: Model exposing ``embed_multimodal``.
            mm_hashes: Cache keys, parallel to ``mm_kwargs``.
            mm_kwargs: Encoder inputs, parallel to ``mm_hashes``.

        Returns:
            The encoder outputs in the order they were produced; an empty list
            when ``mm_hashes`` is empty (the model is not called).
        """
        if not mm_hashes:
            return []

        encoder_outputs: list[torch.Tensor] = []
        # The modality label itself is not needed here, only the grouping.
        for _modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=False,
        ):
            curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=num_items,
            )
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs by mm_hash.
        # NOTE(review): this pairing assumes the grouped outputs come back in
        # the same order as mm_hashes -- confirm against
        # group_mm_kwargs_by_modality.
        for mm_hash, output in zip(mm_hashes, encoder_outputs):
            self.encoder_cache[mm_hash] = output
        return encoder_outputs

    def gather_mm_embeddings(
        self,
        req_ids: list[str],
        total_num_scheduled_tokens: int,
        num_scheduled_tokens: np.ndarray,
        query_start_loc: np.ndarray,
        prefill_lens: np.ndarray,
        computed_prefill_lens: np.ndarray,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Gather cached encoder outputs that overlap this step's tokens.

        All array arguments are indexed by position in ``req_ids``.

        Args:
            req_ids: Request ids in batch order.
            total_num_scheduled_tokens: Length of the returned boolean mask.
            num_scheduled_tokens: Tokens scheduled per request in this step.
            query_start_loc: Start offset of each request's tokens inside the
                flattened batch.
            prefill_lens: Prefill length per request.
            computed_prefill_lens: Prefill tokens already computed per request.

        Returns:
            ``(mm_embeds, is_mm_embed)``: the slices of cached encoder output
            needed in this step, and a boolean mask over the scheduled tokens
            marking the positions those embeddings fill. When no request is
            still prefilling, the list is empty and the mask is all ``False``
            on ``self.device``.

        Raises:
            KeyError: If a prefilling request id has not been registered.
            AssertionError: If a needed encoder output is not in the cache.
        """
        is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
        all_decode = not any(is_prefilling)
        if all_decode:
            # All decode requests, so no need to gather any embeddings.
            return [], torch.zeros(
                total_num_scheduled_tokens,
                dtype=torch.bool,
                device=self.device,
            )

        # Per-request token range [query_start, query_end) handled this step.
        query_start = computed_prefill_lens.tolist()
        query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()

        mm_embeds: list[torch.Tensor] = []
        # Built on the CPU first; copied to the GPU once after the loop.
        is_mm_embed = torch.zeros(
            total_num_scheduled_tokens,
            dtype=torch.bool,
            device="cpu",
            pin_memory=False,
        )
        for i, req_id in enumerate(req_ids):
            if not is_prefilling[i]:
                # OPTIMIZATION: Skip decode requests.
                continue

            mm_features = self.req_id_to_mm_features[req_id]
            for mm_feature in mm_features:
                pos_info = mm_feature.mm_position
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                if start_pos >= query_end[i]:
                    # The encoder output is not needed in this step.
                    # NOTE(review): `break` (not `continue`) presumes the
                    # features are ordered by offset -- confirm with caller.
                    break
                if start_pos + num_encoder_tokens <= query_start[i]:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                # Overlap of this step's range with the placeholder, expressed
                # relative to the start of the placeholder.
                start_idx = max(query_start[i] - start_pos, 0)
                end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
                assert start_idx < end_idx
                curr_embeds_start, curr_embeds_end = (
                    pos_info.get_embeds_indices_in_range(start_idx, end_idx)
                )
                # If there are no embeddings in the current range, we skip
                # gathering the embeddings.
                if curr_embeds_start == curr_embeds_end:
                    continue

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash)
                assert encoder_output is not None, (
                    f"Encoder cache miss for {mm_hash}."
                )

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]
                    mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
                else:
                    mm_embeds_item = encoder_output[start_idx:end_idx]

                # Position of the placeholder start inside the flattened batch.
                req_start_pos = query_start_loc[i] + start_pos - query_start[i]
                is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                    True if is_embed is None else is_embed
                )
                mm_embeds.append(mm_embeds_item)

        # Copy the is_mm_embed tensor to the GPU.
        is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed)
        return mm_embeds, is_mm_embed

    @torch.inference_mode()
    def get_inputs_embeds(
        self,
        model: SupportsMultiModal,
        input_ids: torch.Tensor,
        mm_embeds: list[torch.Tensor],
        is_mm_embed: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and store the result in the shared buffer.

        Args:
            model: Model exposing ``embed_input_ids``.
            input_ids: Token ids to embed.
            mm_embeds: Multimodal embeddings forwarded as
                ``multimodal_embeddings``.
            is_mm_embed: Boolean mask forwarded as ``is_multimodal``.

        Returns:
            The whole ``self.inputs_embeds`` buffer, not a slice. Only its
            first ``x.shape[0]`` rows are overwritten by this call; callers
            slice it to the length they need.
        """
        x = model.embed_input_ids(
            input_ids,
            multimodal_embeddings=mm_embeds,
            is_multimodal=is_mm_embed,
        )
        # Copy to the pre-allocated buffer for CUDA graphs.
        self.inputs_embeds[: x.shape[0]] = x
        return self.inputs_embeds
vllm/v1/worker/gpu/mm/mrope_utils.py
View file @
bb1848cd
...
@@ -23,7 +23,7 @@ class MRopeState:
...
@@ -23,7 +23,7 @@ class MRopeState:
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
# wasting a lot of CPU memory.
self
.
prefill_mrope_positions
=
StagedWriteTensor
(
self
.
prefill_mrope_positions
=
StagedWriteTensor
(
(
max_num_reqs
,
3
*
max_model_len
),
(
max_num_reqs
*
3
,
max_model_len
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
uva_instead_of_gpu
=
True
,
uva_instead_of_gpu
=
True
,
...
@@ -58,9 +58,7 @@ class MRopeState:
...
@@ -58,9 +58,7 @@ class MRopeState:
)
)
for
i
in
range
(
3
):
for
i
in
range
(
3
):
pos
=
prefill_mrope_positions
[
i
].
tolist
()
pos
=
prefill_mrope_positions
[
i
].
tolist
()
self
.
prefill_mrope_positions
.
stage_write
(
self
.
prefill_mrope_positions
.
stage_write
(
3
*
req_idx
+
i
,
0
,
pos
)
req_idx
,
i
*
self
.
max_model_len
,
pos
)
self
.
prefill_mrope_delta
.
np
[
req_idx
]
=
prefill_mrope_delta
self
.
prefill_mrope_delta
.
np
[
req_idx
]
=
prefill_mrope_delta
def
apply_staged_writes
(
self
)
->
None
:
def
apply_staged_writes
(
self
)
->
None
:
...
@@ -79,7 +77,7 @@ class MRopeState:
...
@@ -79,7 +77,7 @@ class MRopeState:
self
.
mrope_positions
,
self
.
mrope_positions
,
self
.
mrope_positions
.
stride
(
0
),
self
.
mrope_positions
.
stride
(
0
),
self
.
prefill_mrope_positions
.
gpu
,
self
.
prefill_mrope_positions
.
gpu
,
self
.
prefill_mrope_positions
.
gpu
.
stride
(
0
)
,
3
*
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_model_len
,
self
.
prefill_mrope_delta
.
gpu
,
self
.
prefill_mrope_delta
.
gpu
,
idx_mapping
,
idx_mapping
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
bb1848cd
...
@@ -14,6 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
...
@@ -14,6 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
@@ -47,6 +48,7 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -47,6 +48,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
prepare_prefill_inputs
,
)
)
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
...
@@ -95,6 +97,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -95,6 +97,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
inputs_embeds_size
=
self
.
model_config
.
get_inputs_embeds_size
()
self
.
inputs_embeds_size
=
self
.
model_config
.
get_inputs_embeds_size
()
# Multimodal
# Multimodal
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
self
.
model_config
)
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
=
EncoderRunner
(
max_num_tokens
=
self
.
max_num_tokens
,
hidden_size
=
self
.
inputs_embeds_size
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
self
.
uses_mrope
=
self
.
model_config
.
uses_mrope
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
self
.
mrope_states
=
MRopeState
(
self
.
mrope_states
=
MRopeState
(
...
@@ -134,9 +147,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -134,9 +147,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
input_buffers
=
InputBuffers
(
self
.
input_buffers
=
InputBuffers
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_num_tokens
=
self
.
max_num_tokens
,
inputs_embeds_size
=
self
.
inputs_embeds_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
sampler
=
Sampler
(
self
.
sampler
=
Sampler
(
...
@@ -289,6 +299,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -289,6 +299,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
mrope_positions
=
self
.
mrope_states
.
mrope_positions
[
input_batch
.
mrope_positions
=
self
.
mrope_states
.
mrope_positions
[
:,
:
num_tokens
:,
:
num_tokens
]
]
if
self
.
supports_mm_inputs
:
input_batch
.
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
[:
num_tokens
]
if
not
skip_attn
:
if
not
skip_attn
:
self
.
prepare_dummy_attn_metadata
(
input_batch
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
...
@@ -314,6 +326,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -314,6 +326,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
input_batch
.
inputs_embeds
,
)
)
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
return
hidden_states
,
sample_hidden_states
...
@@ -378,10 +391,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -378,10 +391,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions
=
None
mrope_positions
=
None
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
mrope_positions
=
self
.
mrope_states
.
mrope_positions
mrope_positions
=
self
.
mrope_states
.
mrope_positions
inputs_embeds
=
None
if
self
.
supports_mm_inputs
:
inputs_embeds
=
self
.
encoder_runner
.
inputs_embeds
self
.
cudagraph_manager
.
capture
(
self
.
cudagraph_manager
.
capture
(
model
=
self
.
model
,
model
=
self
.
model
,
input_buffers
=
self
.
input_buffers
,
input_buffers
=
self
.
input_buffers
,
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
inputs_embeds
,
block_tables
=
self
.
block_tables
,
block_tables
=
self
.
block_tables
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
kv_cache_config
=
self
.
kv_cache_config
,
kv_cache_config
=
self
.
kv_cache_config
,
...
@@ -412,8 +429,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -412,8 +429,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
scheduler_output
.
preempted_req_ids
is
not
None
:
if
scheduler_output
.
preempted_req_ids
is
not
None
:
for
req_id
in
scheduler_output
.
preempted_req_ids
:
for
req_id
in
scheduler_output
.
preempted_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
remove_request
(
req_id
)
for
req_id
in
scheduler_output
.
finished_req_ids
:
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
req_states
.
remove_request
(
req_id
)
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
remove_request
(
req_id
)
if
self
.
supports_mm_inputs
:
for
mm_hash
in
scheduler_output
.
free_encoder_mm_hashes
:
self
.
encoder_runner
.
free_encoder_cache
(
mm_hash
)
# Add new requests.
# Add new requests.
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
...
@@ -432,13 +457,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -432,13 +457,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
add_request
(
req_id
,
new_req_data
.
mm_features
)
# Pre-compute M-RoPE positions for prefill.
# Pre-compute M-RoPE positions for prefill.
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
self
.
mrope_states
.
init_prefill_mrope_positions
(
self
.
mrope_states
.
init_prefill_mrope_positions
(
req_index
,
req_index
,
self
.
model
,
# type: ignore
self
.
model
,
# type: ignore
new_req_data
.
prefill_token_ids
,
new_req_data
.
prefill_token_ids
,
mm_features
=
[],
# TODO
mm_features
=
new_req_data
.
mm_features
,
)
)
self
.
block_tables
.
append_block_ids
(
self
.
block_tables
.
append_block_ids
(
...
@@ -632,12 +660,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -632,12 +660,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
inputs_embeds
=
None
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
cu_num_logits_np
=
cu_num_logits_np
,
)
)
@torch.inference_mode()
def get_mm_embeddings(
    self,
    scheduled_encoder_inputs: dict[str, list[int]],
    input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
    """Run the multimodal encoder, then gather the embeddings for this batch.

    The work is delegated to ``self.encoder_runner`` in three steps: collect
    the scheduled encoder inputs, execute the encoder on them (its return
    value is not used here), and gather the embeddings that the current
    batch's tokens need.

    Args:
        scheduled_encoder_inputs: Maps a request id to the indices of its
            multimodal inputs scheduled for encoding.
        input_batch: The batch being executed. This method reads its
            ``req_ids``, ``num_tokens``, ``num_scheduled_tokens``,
            ``query_start_loc_np`` and ``idx_mapping_np``.

    Returns:
        ``(mm_embeds, is_mm_embed)`` exactly as returned by
        ``encoder_runner.gather_mm_embeddings``.
    """
    mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
        scheduled_encoder_inputs
    )
    self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
    # Per-request state arrays are re-indexed with idx_mapping_np before being
    # passed on, so they line up with the batch's request order.
    mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
        input_batch.req_ids,
        input_batch.num_tokens,
        input_batch.num_scheduled_tokens,
        input_batch.query_start_loc_np,
        self.req_states.prefill_len.np[input_batch.idx_mapping_np],
        self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
    )
    return mm_embeds, is_mm_embed
def
sample
(
def
sample
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -930,6 +979,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -930,6 +979,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
num_scheduled_tokens
,
input_batch
.
num_scheduled_tokens
,
)
)
self
.
_set_active_loras
(
*
lora_inputs
)
self
.
_set_active_loras
(
*
lora_inputs
)
if
self
.
supports_mm_inputs
:
# Execute the multimodal encoder.
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
)
inputs_embeds
=
self
.
encoder_runner
.
get_inputs_embeds
(
self
.
model
,
input_batch
.
input_ids
,
mm_embeds
,
is_mm_embed
)
input_batch
.
inputs_embeds
=
inputs_embeds
[
:
input_batch
.
num_tokens_after_padding
]
else
:
else
:
# No actual tokens to run. A dummy run for DP.
# No actual tokens to run. A dummy run for DP.
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
...
@@ -970,6 +1031,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -970,6 +1031,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
positions
=
positions
,
inputs_embeds
=
input_batch
.
inputs_embeds
,
)
)
self
.
execute_model_state
=
hidden_states
,
input_batch
self
.
execute_model_state
=
hidden_states
,
input_batch
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
bb1848cd
...
@@ -48,9 +48,6 @@ class EagleSpeculator:
...
@@ -48,9 +48,6 @@ class EagleSpeculator:
self
.
input_buffers
=
InputBuffers
(
self
.
input_buffers
=
InputBuffers
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_num_tokens
=
self
.
max_num_tokens
,
max_num_tokens
=
self
.
max_num_tokens
,
inputs_embeds_size
=
self
.
inputs_embeds_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
device
,
device
=
device
,
)
)
self
.
hidden_states
=
torch
.
zeros
(
self
.
hidden_states
=
torch
.
zeros
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment