Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c2cfb62
Unverified
Commit
6c2cfb62
authored
Dec 31, 2025
by
Nick Hill
Committed by
GitHub
Dec 31, 2025
Browse files
[BugFix] Fix async scheduling for pooling models (#31584)
Signed-off-by:
njhill
<
nickhill123@gmail.com
>
parent
d8da76f3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
131 additions
and
92 deletions
+131
-92
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+5
-2
vllm/v1/outputs.py
vllm/v1/outputs.py
+7
-18
vllm/v1/pool/metadata.py
vllm/v1/pool/metadata.py
+20
-28
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+1
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+92
-35
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-2
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+2
-2
No files found.
vllm/v1/executor/ray_utils.py
View file @
6c2cfb62
...
@@ -104,8 +104,11 @@ try:
...
@@ -104,8 +104,11 @@ try:
scheduler_output
,
intermediate_tensors
scheduler_output
,
intermediate_tensors
)
)
if
isinstance
(
output
,
IntermediateTensors
):
if
isinstance
(
output
,
IntermediateTensors
):
output
=
scheduler_output
,
grammar_output
,
output
return
scheduler_output
,
grammar_output
,
output
elif
not
get_pp_group
().
is_last_rank
:
if
isinstance
(
output
,
AsyncModelRunnerOutput
):
output
=
output
.
get_output
()
if
not
get_pp_group
().
is_last_rank
:
# Case where there are no scheduled requests
# Case where there are no scheduled requests
# but may still be finished requests.
# but may still be finished requests.
assert
not
output
or
not
output
.
req_ids
assert
not
output
or
not
output
.
req_ids
...
...
vllm/v1/outputs.py
View file @
6c2cfb62
...
@@ -151,21 +151,23 @@ class ModelRunnerOutput:
...
@@ -151,21 +151,23 @@ class ModelRunnerOutput:
# num_generated_tokens is the number of tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
# each request due to speculative/jump decoding.
sampled_token_ids
:
list
[
list
[
int
]]
sampled_token_ids
:
list
[
list
[
int
]]
=
field
(
default_factory
=
list
)
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]
# [num_reqs]
logprobs
:
LogprobsLists
|
None
logprobs
:
LogprobsLists
|
None
=
None
# req_id -> (token_ids, logprobs, ranks)
# req_id -> (token_ids, logprobs, ranks)
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
# [prompt_len]
prompt_logprobs_dict
:
dict
[
str
,
LogprobsTensors
|
None
]
prompt_logprobs_dict
:
dict
[
str
,
LogprobsTensors
|
None
]
=
field
(
default_factory
=
dict
)
# [num_reqs, hidden_size]
# [num_reqs, hidden_size]
pooler_output
:
list
[
torch
.
Tensor
|
None
]
pooler_output
:
list
[
torch
.
Tensor
|
None
]
|
None
=
None
kv_connector_output
:
KVConnectorOutput
|
None
=
None
kv_connector_output
:
KVConnectorOutput
|
None
=
None
...
@@ -225,21 +227,8 @@ def make_empty_encoder_model_runner_output(
...
@@ -225,21 +227,8 @@ def make_empty_encoder_model_runner_output(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
req_id_to_index
=
req_id_to_index
,
req_id_to_index
=
req_id_to_index
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
pooler_output
,
pooler_output
=
pooler_output
,
kv_connector_output
=
None
,
ec_connector_output
=
None
,
num_nans_in_logits
=
None
,
)
)
EMPTY_MODEL_RUNNER_OUTPUT
=
ModelRunnerOutput
(
EMPTY_MODEL_RUNNER_OUTPUT
=
ModelRunnerOutput
(
req_ids
=
[],
req_id_to_index
=
{})
req_ids
=
[],
req_id_to_index
=
{},
sampled_token_ids
=
[],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[],
num_nans_in_logits
=
None
,
)
vllm/v1/pool/metadata.py
View file @
6c2cfb62
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
numpy
as
np
import
torch
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -91,36 +92,27 @@ class PoolingMetadata:
...
@@ -91,36 +92,27 @@ class PoolingMetadata:
def
build_pooling_cursor
(
def
build_pooling_cursor
(
self
,
self
,
num_scheduled_tokens
:
list
[
int
]
,
num_scheduled_tokens
_np
:
np
.
ndarray
,
seq_lens_cpu
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
device
:
torch
.
device
,
device
:
torch
.
device
,
):
):
self
.
pooling_cursor
=
build_pooling_cursor
(
n_seq
=
len
(
num_scheduled_tokens_np
)
num_scheduled_tokens
,
seq_lens_cpu
,
self
.
prompt_lens
,
device
prompt_lens
=
self
.
prompt_lens
)
assert
len
(
prompt_lens
)
==
n_seq
def
build_pooling_cursor
(
index
=
list
(
range
(
n_seq
))
num_scheduled_tokens
:
list
[
int
],
num_scheduled_tokens_cpu
=
torch
.
from_numpy
(
num_scheduled_tokens_np
)
seq_lens_cpu
:
torch
.
Tensor
,
cumsum
=
torch
.
zeros
(
prompt_lens
:
torch
.
Tensor
,
n_seq
+
1
,
dtype
=
torch
.
int64
,
pin_memory
=
pin_memory
,
device
=
"cpu"
device
:
torch
.
device
,
)
):
torch
.
cumsum
(
num_scheduled_tokens_cpu
,
dim
=
0
,
out
=
cumsum
[
1
:])
assert
len
(
prompt_lens
)
==
len
(
num_scheduled_tokens
)
cumsum
=
cumsum
.
to
(
device
,
non_blocking
=
True
)
self
.
pooling_cursor
=
PoolingCursor
(
n_seq
=
len
(
num_scheduled_tokens
)
index
=
index
,
index
=
list
(
range
(
n_seq
))
first_token_indices_gpu
=
cumsum
[:
n_seq
],
num_scheduled_tokens_cpu
=
torch
.
tensor
(
num_scheduled_tokens
,
device
=
"cpu"
)
last_token_indices_gpu
=
cumsum
[
1
:]
-
1
,
cumsum
=
torch
.
zeros
(
prompt_lens_cpu
=
prompt_lens
,
n_seq
+
1
,
dtype
=
torch
.
int64
,
pin_memory
=
pin_memory
,
device
=
"cpu"
seq_lens_cpu
=
seq_lens_cpu
,
)
num_scheduled_tokens_cpu
=
num_scheduled_tokens_cpu
,
torch
.
cumsum
(
num_scheduled_tokens_cpu
,
dim
=
0
,
out
=
cumsum
[
1
:])
)
cumsum
=
cumsum
.
to
(
device
,
non_blocking
=
True
)
return
PoolingCursor
(
index
=
index
,
first_token_indices_gpu
=
cumsum
[:
n_seq
],
last_token_indices_gpu
=
cumsum
[
1
:]
-
1
,
prompt_lens_cpu
=
prompt_lens
,
seq_lens_cpu
=
seq_lens_cpu
,
num_scheduled_tokens_cpu
=
num_scheduled_tokens_cpu
,
)
vllm/v1/worker/gpu/model_runner.py
View file @
6c2cfb62
...
@@ -968,11 +968,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -968,11 +968,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Only for compatibility with the existing model runner and scheduler.
# Only for compatibility with the existing model runner and scheduler.
req_id_to_index
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
)},
req_id_to_index
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
)},
sampled_token_ids
=
None
,
# type: ignore
sampled_token_ids
=
None
,
# type: ignore
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
# type: ignore[arg-type]
prompt_logprobs_dict
=
prompt_logprobs_dict
,
# type: ignore
pooler_output
=
[],
kv_connector_output
=
None
,
num_nans_in_logits
=
None
,
)
)
async_output
=
AsyncOutput
(
async_output
=
AsyncOutput
(
model_runner_output
=
model_runner_output
,
model_runner_output
=
model_runner_output
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6c2cfb62
...
@@ -254,6 +254,50 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -254,6 +254,50 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return
output
return
output
class
AsyncGPUPoolingModelRunnerOutput
(
AsyncModelRunnerOutput
):
def
__init__
(
self
,
model_runner_output
:
ModelRunnerOutput
,
raw_pooler_output
:
PoolerOutput
,
finished_mask
:
list
[
bool
],
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
):
self
.
_model_runner_output
=
model_runner_output
self
.
_finished_mask
=
finished_mask
# Event on the copy stream so we can synchronize the non-blocking copy.
self
.
async_copy_ready_event
=
torch
.
Event
()
# Keep a reference to the device tensors to avoid them being
# deallocated until we finish copying it to the host.
self
.
_raw_pooler_output
=
raw_pooler_output
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
async_output_copy_stream
):
async_output_copy_stream
.
wait_stream
(
default_stream
)
self
.
_raw_pooler_output_cpu
=
json_map_leaves
(
lambda
x
:
None
if
x
is
None
else
x
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
_raw_pooler_output
,
)
self
.
async_copy_ready_event
.
record
()
def
get_output
(
self
)
->
ModelRunnerOutput
:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
This function blocks until the copy is finished.
"""
self
.
async_copy_ready_event
.
synchronize
()
# Release the device tensors once the copy has completed.
del
self
.
_raw_pooler_output
self
.
_model_runner_output
.
pooler_output
=
[
out
if
include
else
None
for
out
,
include
in
zip
(
self
.
_raw_pooler_output_cpu
,
self
.
_finished_mask
)
]
return
self
.
_model_runner_output
class
ExecuteModelState
(
NamedTuple
):
class
ExecuteModelState
(
NamedTuple
):
"""Ephemeral cached state transferred between execute_model() and
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
sample_tokens(), after execute_model() returns None."""
...
@@ -2476,17 +2520,19 @@ class GPUModelRunner(
...
@@ -2476,17 +2520,19 @@ class GPUModelRunner(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens_np
:
np
.
ndarray
,
num_scheduled_tokens_np
:
np
.
ndarray
,
)
->
ModelRunnerOutput
:
kv_connector_output
:
KVConnectorOutput
|
None
,
assert
self
.
input_batch
.
num_reqs
==
len
(
self
.
input_batch
.
pooling_params
),
(
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
:
num_reqs
=
self
.
input_batch
.
num_reqs
assert
num_reqs
==
len
(
self
.
input_batch
.
pooling_params
),
(
"Either all or none of the requests in a batch must be pooling request"
"Either all or none of the requests in a batch must be pooling request"
)
)
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
self
.
input_batch
.
num_reqs
]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs
]
pooling_metadata
=
self
.
input_batch
.
get_pooling_metadata
()
pooling_metadata
=
self
.
input_batch
.
get_pooling_metadata
()
pooling_metadata
.
build_pooling_cursor
(
pooling_metadata
.
build_pooling_cursor
(
num_scheduled_tokens_np
.
tolist
()
,
seq_lens_cpu
,
device
=
hidden_states
.
device
num_scheduled_tokens_np
,
seq_lens_cpu
,
device
=
hidden_states
.
device
)
)
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
...
@@ -2494,27 +2540,41 @@ class GPUModelRunner(
...
@@ -2494,27 +2540,41 @@ class GPUModelRunner(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
,
pooling_metadata
=
pooling_metadata
,
)
)
finished_mask
=
[
seq_len
==
prompt_len
for
seq_len
,
prompt_len
in
zip
(
seq_lens_cpu
,
pooling_metadata
.
prompt_lens
)
]
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
.
copy
(),
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
.
copy
(),
kv_connector_output
=
kv_connector_output
,
)
if
raw_pooler_output
is
None
or
not
any
(
finished_mask
):
model_runner_output
.
pooler_output
=
[
None
]
*
num_reqs
return
model_runner_output
if
self
.
use_async_scheduling
:
return
AsyncGPUPoolingModelRunnerOutput
(
model_runner_output
=
model_runner_output
,
raw_pooler_output
=
raw_pooler_output
,
finished_mask
=
finished_mask
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
)
raw_pooler_output
=
json_map_leaves
(
raw_pooler_output
=
json_map_leaves
(
lambda
x
:
x
.
to
(
"cpu"
,
non_blocking
=
True
)
if
x
is
not
None
else
x
,
lambda
x
:
None
if
x
is
None
else
x
.
to
(
"cpu"
,
non_blocking
=
True
),
raw_pooler_output
,
raw_pooler_output
,
)
)
self
.
_sync_device
()
self
.
_sync_device
()
pooler_output
:
list
[
torch
.
Tensor
|
None
]
=
[]
model_runner_output
.
pooler_output
=
[
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
out
if
include
else
None
raw_pooler_output
,
seq_lens_cpu
,
pooling_metadata
.
prompt_lens
for
out
,
include
in
zip
(
raw_pooler_output
,
finished_mask
)
):
]
output
=
raw_output
if
seq_len
==
prompt_len
else
None
return
model_runner_output
pooler_output
.
append
(
output
)
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
[],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
pooler_output
,
)
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
# Pad tokens to multiple of tensor_parallel_size when
# Pad tokens to multiple of tensor_parallel_size when
...
@@ -3036,7 +3096,7 @@ class GPUModelRunner(
...
@@ -3036,7 +3096,7 @@ class GPUModelRunner(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
)
->
ModelRunnerOutput
|
IntermediateTensors
|
None
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
IntermediateTensors
|
None
:
if
self
.
execute_model_state
is
not
None
:
if
self
.
execute_model_state
is
not
None
:
raise
RuntimeError
(
raise
RuntimeError
(
"State error: sample_tokens() must be called "
"State error: sample_tokens() must be called "
...
@@ -3244,11 +3304,12 @@ class GPUModelRunner(
...
@@ -3244,11 +3304,12 @@ class GPUModelRunner(
if
self
.
is_pooling_model
:
if
self
.
is_pooling_model
:
# Return the pooling output.
# Return the pooling output.
output
=
self
.
_pool
(
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
kv_connector_output
,
)
)
output
.
kv_connector_output
=
kv_connector_output
return
output
sample_hidden_states
=
hidden_states
[
logits_indices
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
...
@@ -3437,7 +3498,6 @@ class GPUModelRunner(
...
@@ -3437,7 +3498,6 @@ class GPUModelRunner(
sampled_token_ids
=
valid_sampled_token_ids
,
sampled_token_ids
=
valid_sampled_token_ids
,
logprobs
=
logprobs_lists
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
kv_connector_output
=
kv_connector_output
,
kv_connector_output
=
kv_connector_output
,
ec_connector_output
=
ec_connector_output
ec_connector_output
=
ec_connector_output
if
self
.
supports_mm_inputs
if
self
.
supports_mm_inputs
...
@@ -4508,17 +4568,14 @@ class GPUModelRunner(
...
@@ -4508,17 +4568,14 @@ class GPUModelRunner(
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
min_tokens_per_req
=
num_tokens
//
num_reqs
min_tokens_per_req
=
num_tokens
//
num_reqs
num_scheduled_tokens_
list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_
np
=
np
.
full
(
num_reqs
,
min_tokens_per_req
)
num_scheduled_tokens_
list
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens_
np
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_
list
)
==
num_tokens
assert
np
.
sum
(
num_scheduled_tokens_
np
)
==
num_tokens
assert
len
(
num_scheduled_tokens_
list
)
==
num_reqs
assert
len
(
num_scheduled_tokens_
np
)
==
num_reqs
req_num_tokens
=
num_tokens
//
num_reqs
req_num_tokens
=
num_tokens
//
num_reqs
dummy_prompt_lens
=
torch
.
tensor
(
dummy_prompt_lens
=
torch
.
from_numpy
(
num_scheduled_tokens_np
)
num_scheduled_tokens_list
,
device
=
"cpu"
,
)
dummy_token_ids
=
torch
.
zeros
(
dummy_token_ids
=
torch
.
zeros
(
(
num_reqs
,
req_num_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
num_reqs
,
req_num_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
...
@@ -4537,7 +4594,7 @@ class GPUModelRunner(
...
@@ -4537,7 +4594,7 @@ class GPUModelRunner(
)
)
dummy_metadata
.
build_pooling_cursor
(
dummy_metadata
.
build_pooling_cursor
(
num_scheduled_tokens_
list
,
num_scheduled_tokens_
np
,
seq_lens_cpu
=
dummy_prompt_lens
,
seq_lens_cpu
=
dummy_prompt_lens
,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
)
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
6c2cfb62
...
@@ -575,7 +575,7 @@ class Worker(WorkerBase):
...
@@ -575,7 +575,7 @@ class Worker(WorkerBase):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
None
:
intermediate_tensors
=
None
intermediate_tensors
=
None
forward_pass
=
scheduler_output
.
total_num_scheduled_tokens
>
0
forward_pass
=
scheduler_output
.
total_num_scheduled_tokens
>
0
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
...
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
...
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
output
=
self
.
model_runner
.
execute_model
(
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
scheduler_output
,
intermediate_tensors
)
)
if
isinstance
(
output
,
ModelRunnerOutput
|
NoneType
):
if
isinstance
(
output
,
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
NoneType
):
return
output
return
output
assert
isinstance
(
output
,
IntermediateTensors
)
assert
isinstance
(
output
,
IntermediateTensors
)
...
...
vllm/v1/worker/worker_base.py
View file @
6c2cfb62
...
@@ -124,7 +124,7 @@ class WorkerBase:
...
@@ -124,7 +124,7 @@ class WorkerBase:
def
execute_model
(
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
self
,
scheduler_output
:
SchedulerOutput
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
None
:
"""If this method returns None, sample_tokens should be called immediately after
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
to obtain the ModelRunnerOutput.
...
@@ -362,7 +362,7 @@ class WorkerWrapperBase:
...
@@ -362,7 +362,7 @@ class WorkerWrapperBase:
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
None
:
self
.
_apply_mm_cache
(
scheduler_output
)
self
.
_apply_mm_cache
(
scheduler_output
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment