Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0cdbe7b7
Unverified
Commit
0cdbe7b7
authored
Oct 31, 2025
by
Nick Hill
Committed by
GitHub
Nov 01, 2025
Browse files
[Core] Async scheduling + structured outputs compatibility (#26866)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
df334868
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
141 additions
and
37 deletions
+141
-37
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+68
-11
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+12
-5
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+35
-10
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+6
-7
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+20
-4
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
0cdbe7b7
...
@@ -109,6 +109,7 @@ from vllm.v1.outputs import (
...
@@ -109,6 +109,7 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT
,
EMPTY_MODEL_RUNNER_OUTPUT
,
AsyncModelRunnerOutput
,
AsyncModelRunnerOutput
,
DraftTokenIds
,
DraftTokenIds
,
KVConnectorOutput
,
LogprobsLists
,
LogprobsLists
,
LogprobsTensors
,
LogprobsTensors
,
ModelRunnerOutput
,
ModelRunnerOutput
,
...
@@ -150,7 +151,7 @@ from .utils import (
...
@@ -150,7 +151,7 @@ from .utils import (
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return
output
return
output
class
ExecuteModelState
(
NamedTuple
):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
scheduler_output
:
"SchedulerOutput"
logits
:
torch
.
Tensor
spec_decode_metadata
:
SpecDecodeMetadata
|
None
spec_decode_common_attn_metadata
:
CommonAttentionMetadata
|
None
hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
kv_connector_output
:
KVConnectorOutput
|
None
class
GPUModelRunner
(
LoRAModelRunnerMixin
,
KVConnectorModelRunnerMixin
):
class
GPUModelRunner
(
LoRAModelRunnerMixin
,
KVConnectorModelRunnerMixin
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
)
)
# Ephemeral state transferred between execute_model() and sample_tokens().
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
if
self
.
mm_budget
:
if
self
.
mm_budget
:
self
.
mm_budget
.
reset_cache
()
self
.
mm_budget
.
reset_cache
()
...
@@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_input_tokens
:
int
,
# Padded
num_input_tokens
:
int
,
# Padded
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
)
->
tuple
[
)
->
tuple
[
int
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs
.
update
(
encoder_inputs
)
model_kwargs
.
update
(
encoder_inputs
)
return
(
return
(
num_scheduled_tokens
,
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
positions
,
positions
,
...
@@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
)
->
ModelRunnerOutput
|
IntermediateTensors
|
None
:
if
self
.
execute_model_state
is
not
None
:
raise
RuntimeError
(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
with
record_function_or_nullcontext
(
"Preprocess"
):
with
record_function_or_nullcontext
(
"Preprocess"
):
with
self
.
synchronize_input_prep
():
with
self
.
synchronize_input_prep
():
# Update persistent batch states.
# Update persistent batch states.
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_
num_scheduled_tokens
:
if
not
num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
if
not
has_kv_transfer_group
():
# Return empty ModelRunnerOutput if no work to do.
# Return empty ModelRunnerOutput if no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
EMPTY_MODEL_RUNNER_OUTPUT
...
@@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
(
(
num_scheduled_tokens
,
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
positions
,
positions
,
...
@@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Rare case.
# Rare case.
assert
not
self
.
is_pooling_model
assert
not
self
.
is_pooling_model
sample_hidden_states
=
hidden_states
[
logits_indices
]
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
all_gather_tensors
=
{
all_gather_tensors
=
{
"residual"
:
not
is_residual_scattered_for_sp
(
"residual"
:
not
is_residual_scattered_for_sp
(
...
@@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
logits
=
None
logits
=
None
else
:
else
:
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
model_output_broadcast_data
=
{}
model_output_broadcast_data
=
{}
...
@@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
model_output_broadcast_data
is
not
None
assert
model_output_broadcast_data
is
not
None
logits
=
model_output_broadcast_data
[
"logits"
]
logits
=
model_output_broadcast_data
[
"logits"
]
# Apply structured output bitmasks if present
self
.
execute_model_state
=
ExecuteModelState
(
if
scheduler_output
.
structured_output_request_ids
:
scheduler_output
,
apply_grammar_bitmask
(
scheduler_output
,
self
.
input_batch
,
logits
)
logits
,
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
return
None
@
torch
.
inference_mode
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput | None"
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
:
if
self
.
execute_model_state
is
None
:
# Nothing to do (PP non-final rank case), output isn't used.
return
None
# noqa
# Unpack ephemeral state.
(
scheduler_output
,
logits
,
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
kv_connector_output
,
)
=
self
.
execute_model_state
# Clear ephemeral state.
self
.
execute_model_state
=
None
# Apply structured output bitmasks if present.
if
grammar_output
is
not
None
:
apply_grammar_bitmask
(
scheduler_output
,
grammar_output
,
self
.
input_batch
,
logits
)
with
record_function_or_nullcontext
(
"Sample"
):
with
record_function_or_nullcontext
(
"Sample"
):
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
...
@@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output
,
sampler_output
,
logits
,
logits
,
hidden_states
,
hidden_states
,
num_scheduled_tokens
,
scheduler_output
.
total_
num_scheduled_tokens
,
spec_decode_metadata
,
spec_decode_metadata
,
)
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
0cdbe7b7
...
@@ -6,6 +6,7 @@ import copy
...
@@ -6,6 +6,7 @@ import copy
import
gc
import
gc
import
os
import
os
from
contextlib
import
AbstractContextManager
,
nullcontext
from
contextlib
import
AbstractContextManager
,
nullcontext
from
types
import
NoneType
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
...
@@ -37,6 +38,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -37,6 +38,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
MemorySnapshot
,
memory_profiling
from
vllm.utils.mem_utils
import
MemorySnapshot
,
memory_profiling
from
vllm.v1.core.sched.output
import
GrammarOutput
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.engine
import
ReconfigureDistributedRequest
,
ReconfigureRankType
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
(
from
vllm.v1.outputs
import
(
...
@@ -508,11 +510,16 @@ class Worker(WorkerBase):
...
@@ -508,11 +510,16 @@ class Worker(WorkerBase):
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
return
self
.
model_runner
.
get_supported_tasks
()
return
self
.
model_runner
.
get_supported_tasks
()
@
torch
.
inference_mode
()
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput"
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
:
return
self
.
model_runner
.
sample_tokens
(
grammar_output
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
scheduler_output
:
"SchedulerOutput"
,
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
None
:
intermediate_tensors
=
None
intermediate_tensors
=
None
forward_pass
=
scheduler_output
.
total_num_scheduled_tokens
>
0
forward_pass
=
scheduler_output
.
total_num_scheduled_tokens
>
0
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
...
@@ -531,13 +538,13 @@ class Worker(WorkerBase):
...
@@ -531,13 +538,13 @@ class Worker(WorkerBase):
)
)
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
isinstance
(
output
,
(
ModelRunnerOutput
,
AsyncModelRunnerOutput
)):
if
isinstance
(
output
,
(
ModelRunnerOutput
,
NoneType
)):
return
output
return
output
assert
isinstance
(
output
,
IntermediateTensors
)
assert
isinstance
(
output
,
IntermediateTensors
)
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
=
self
.
vllm_config
.
parallel_config
assert
(
assert
(
parallel_config
.
distributed_executor_backend
!=
(
"external_launcher"
)
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
and
not
get_pp_group
().
is_last_rank
and
not
get_pp_group
().
is_last_rank
)
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
0cdbe7b7
...
@@ -92,7 +92,7 @@ from .utils import (
...
@@ -92,7 +92,7 @@ from .utils import (
)
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
else
:
self
.
sample_from_logits_func
=
self
.
sample_from_logits
self
.
sample_from_logits_func
=
self
.
sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self
.
scheduler_output
:
SchedulerOutput
|
None
=
None
self
.
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
if
self
.
mm_budget
:
if
self
.
mm_budget
:
self
.
mm_budget
.
reset_cache
()
self
.
mm_budget
.
reset_cache
()
...
@@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
)
->
ModelRunnerOutput
:
)
->
ModelRunnerOutput
|
None
:
if
self
.
scheduler_output
is
not
None
:
raise
RuntimeError
(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state
# Update cached state
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
scheduler_output
.
total_num_scheduled_tokens
:
...
@@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
self
.
kv_connector_no_forward
(
scheduler_output
,
self
.
vllm_config
)
return
self
.
kv_connector_no_forward
(
scheduler_output
,
self
.
vllm_config
)
mm_embed_inputs
=
None
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
self
.
_execute_mm_encoder
(
scheduler_output
)
self
.
_execute_mm_encoder
(
scheduler_output
)
mm_embed_inputs
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
mm_embed_inputs
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
else
:
mm_embed_inputs
=
None
torch_xla
.
sync
(
wait
=
False
)
torch_xla
.
sync
(
wait
=
False
)
self
.
scheduler_output
=
scheduler_output
self
.
mm_embed_inputs
=
mm_embed_inputs
return
None
@
torch
.
no_grad
()
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput | None"
)
->
ModelRunnerOutput
:
if
self
.
scheduler_output
is
None
:
# Nothing to do (PP non-final rank case), output isn't used.
return
None
# noqa
scheduler_output
=
self
.
scheduler_output
mm_embed_inputs
=
self
.
mm_embed_inputs
self
.
scheduler_output
=
None
self
.
mm_embed_inputs
=
None
# Prepare inputs, the requests might be split into multiple
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
# executions, combine the result of each execution.
start_index
=
0
start_index
=
0
...
@@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
from_input_batch
(
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
)
if
scheduler_output
.
grammar_
bitmask
is
not
None
:
if
grammar_
output
is
not
None
:
require_struct_decoding
,
grammar_bitmask_padded
,
arange
=
(
require_struct_decoding
,
grammar_bitmask_padded
,
arange
=
(
self
.
prepare_structured_decoding_input
(
logits
,
schedule
r_output
)
self
.
prepare_structured_decoding_input
(
logits
,
gramma
r_output
)
)
)
logits
=
self
.
structured_decode
(
logits
=
self
.
structured_decode
(
require_struct_decoding
,
grammar_bitmask_padded
,
logits
,
arange
require_struct_decoding
,
grammar_bitmask_padded
,
logits
,
arange
...
@@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
prepare_structured_decoding_input
(
def
prepare_structured_decoding_input
(
self
,
logits
:
torch
.
Tensor
,
schedule
r_output
:
"
Schedule
rOutput"
self
,
logits
:
torch
.
Tensor
,
gramma
r_output
:
"
Gramma
rOutput"
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
grammar_bitmask
=
grammar_output
.
grammar_bitmask
assert
grammar_bitmask
is
not
None
num_reqs
,
_
=
logits
.
shape
num_reqs
,
_
=
logits
.
shape
# Reset pre-allocated tensors
# Reset pre-allocated tensors
...
@@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
require_structured_out_cpu
.
zero_
()
self
.
require_structured_out_cpu
.
zero_
()
cumulative_mask_idx
=
0
cumulative_mask_idx
=
0
for
req_id
in
schedule
r_output
.
structured_output_request_ids
:
for
req_id
in
gramma
r_output
.
structured_output_request_ids
:
if
req_id
not
in
self
.
input_batch
.
req_id_to_index
:
if
req_id
not
in
self
.
input_batch
.
req_id_to_index
:
continue
continue
batch_index
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
batch_index
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
...
...
vllm/v1/worker/tpu_worker.py
View file @
0cdbe7b7
...
@@ -17,7 +17,6 @@ from vllm.distributed import (
...
@@ -17,7 +17,6 @@ from vllm.distributed import (
)
)
from
vllm.distributed.kv_transfer
import
(
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_initialized
,
ensure_kv_transfer_initialized
,
has_kv_transfer_group
,
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
...
@@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.utils
import
report_usage_stats
...
@@ -255,13 +254,13 @@ class TPUWorker:
...
@@ -255,13 +254,13 @@ class TPUWorker:
tpu_kv_cache_bytes
=
tpu_kv_cache_bytes
*
head_size
//
padded_head_size
tpu_kv_cache_bytes
=
tpu_kv_cache_bytes
*
head_size
//
padded_head_size
return
int
(
tpu_kv_cache_bytes
)
return
int
(
tpu_kv_cache_bytes
)
def
sample_tokens
(
self
,
grammar_output
:
"GrammarOutput"
)
->
ModelRunnerOutput
:
return
self
.
model_runner
.
sample_tokens
(
grammar_output
)
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
scheduler_output
:
"SchedulerOutput"
,
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
None
:
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
)
return
self
.
model_runner
.
execute_model
(
scheduler_output
)
# every worker's output is needed when kv_transfer_group is set up
return
output
if
self
.
is_driver_worker
or
has_kv_transfer_group
()
else
None
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
if
self
.
rank
<
1
:
if
self
.
rank
<
1
:
...
...
vllm/v1/worker/worker_base.py
View file @
0cdbe7b7
...
@@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
...
@@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
from
vllm.v1.serial_utils
import
run_method
from
vllm.v1.serial_utils
import
run_method
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
AsyncModelRunnerOutput
,
ModelRunnerOutput
else
:
else
:
SchedulerOutput
=
object
SchedulerOutput
=
object
GrammarOutput
=
object
AsyncModelRunnerOutput
=
object
ModelRunnerOutput
=
object
ModelRunnerOutput
=
object
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -122,7 +124,21 @@ class WorkerBase:
...
@@ -122,7 +124,21 @@ class WorkerBase:
"""Load model onto target device."""
"""Load model onto target device."""
raise
NotImplementedError
raise
NotImplementedError
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
)
->
ModelRunnerOutput
:
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
)
->
ModelRunnerOutput
|
None
:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise
NotImplementedError
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
:
"""Should be called immediately after execute_model iff it returned None."""
raise
NotImplementedError
raise
NotImplementedError
def
get_cache_block_size_bytes
(
self
)
->
int
:
def
get_cache_block_size_bytes
(
self
)
->
int
:
...
@@ -344,7 +360,7 @@ class WorkerWrapperBase:
...
@@ -344,7 +360,7 @@ class WorkerWrapperBase:
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
)
->
ModelRunnerOutput
:
)
->
ModelRunnerOutput
|
None
:
self
.
_apply_mm_cache
(
scheduler_output
)
self
.
_apply_mm_cache
(
scheduler_output
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment