Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc876d0f
Unverified
Commit
cc876d0f
authored
Jul 10, 2025
by
Or Ozeri
Committed by
GitHub
Jul 10, 2025
Browse files
[KVConnector] Aggregate finished requests on the scheduler (#19555)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
fdfd409f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
139 additions
and
110 deletions
+139
-110
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+3
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+4
-61
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+105
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-40
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+21
-3
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
cc876d0f
...
@@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
...
@@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
)
->
tuple
[
Optional
[
set
[
str
]],
Optional
[
set
[
str
]]]:
)
->
tuple
[
Optional
[
set
[
str
]],
Optional
[
set
[
str
]]]:
"""
"""
Notifies worker-side connector ids of requests that have
Notifies worker-side connector ids of requests that have
finished generating tokens.
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
Returns:
Returns:
ids of requests that have finished asynchronous transfer
ids of requests that have finished asynchronous transfer
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
cc876d0f
...
@@ -408,14 +408,6 @@ class NixlConnectorWorker:
...
@@ -408,14 +408,6 @@ class NixlConnectorWorker:
# Track the expiration time of requests that are waiting to be sent.
# Track the expiration time of requests that are waiting to be sent.
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self
.
_done_recving_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
self
.
_done_sending_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
# Background thread for handling new handshake requests.
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
# Background thread for initializing new NIXL handshakes.
# Background thread for initializing new NIXL handshakes.
...
@@ -830,15 +822,9 @@ class NixlConnectorWorker:
...
@@ -830,15 +822,9 @@ class NixlConnectorWorker:
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
"""
"""
Get requests that are done sending or recving.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
In TP>1 setup, each rank exchanges KVs with its counterpart
to track which workers are done.
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
"""
"""
done_sending
=
self
.
_get_new_notifs
()
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
...
@@ -858,50 +844,7 @@ class NixlConnectorWorker:
...
@@ -858,50 +844,7 @@ class NixlConnectorWorker:
del
self
.
_reqs_to_send
[
req_id
]
del
self
.
_reqs_to_send
[
req_id
]
done_sending
.
add
(
req_id
)
done_sending
.
add
(
req_id
)
if
self
.
world_size
==
1
:
return
done_sending
,
done_recving
return
done_sending
,
done_recving
# Rank 0: get finished from all other ranks.
if
self
.
tp_rank
==
0
:
for
req_id
in
done_sending
:
self
.
_done_sending_count
[
req_id
]
+=
1
for
req_id
in
done_recving
:
self
.
_done_recving_count
[
req_id
]
+=
1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids
:
list
[
str
]
=
[]
for
i
in
range
(
1
,
self
.
world_size
):
other_ranks_finished_ids
.
extend
(
self
.
tp_group
.
recv_object
(
src
=
i
))
for
req_id
in
other_ranks_finished_ids
:
if
(
req_id
in
self
.
_done_recving_count
or
req_id
in
self
.
_recving_transfers
):
self
.
_done_recving_count
[
req_id
]
+=
1
else
:
self
.
_done_sending_count
[
req_id
]
+=
1
# Return ids that finished on all ranks to the scheduler.
all_done_recving
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_recving_count
.
keys
()):
if
self
.
_done_recving_count
[
req_id
]
==
self
.
world_size
:
del
self
.
_done_recving_count
[
req_id
]
all_done_recving
.
add
(
req_id
)
all_done_sending
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_sending_count
.
keys
()):
if
self
.
_done_sending_count
[
req_id
]
>=
self
.
world_size
:
del
self
.
_done_sending_count
[
req_id
]
all_done_sending
.
add
(
req_id
)
return
all_done_sending
,
all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else
:
finished_req_ids
=
list
(
done_recving
.
union
(
done_sending
))
self
.
tp_group
.
send_object
(
finished_req_ids
,
dst
=
0
)
# Unused as only Rank 0 results are sent to scheduler.
return
done_sending
,
done_recving
def
_get_new_notifs
(
self
)
->
set
[
str
]:
def
_get_new_notifs
(
self
)
->
set
[
str
]:
"""
"""
...
...
vllm/v1/executor/multiproc_executor.py
View file @
cc876d0f
...
@@ -9,7 +9,8 @@ import threading
...
@@ -9,7 +9,8 @@ import threading
import
time
import
time
import
traceback
import
traceback
import
weakref
import
weakref
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
collections
import
defaultdict
from
concurrent.futures
import
CancelledError
,
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
functools
import
partial
from
functools
import
partial
...
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
...
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
if
self
.
max_concurrent_batches
>
1
:
if
self
.
max_concurrent_batches
>
1
:
# Note: must use only 1 IO thread to keep dequeue sequence
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self
.
io_thread_pool
=
ThreadPoolExecutor
(
self
.
io_thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
# Complete transfer tracker. Used by to track finished requests
# [req_id -> n_finished_workers]
self
.
_recv_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
self
.
world_size
)
self
.
_send_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
self
.
world_size
)
def
start_worker_monitor
(
self
):
def
start_worker_monitor
(
self
):
workers
=
self
.
workers
workers
=
self
.
workers
...
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
...
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
self
,
self
,
scheduler_output
,
scheduler_output
,
)
->
Union
[
ModelRunnerOutput
,
Future
[
ModelRunnerOutput
]]:
)
->
Union
[
ModelRunnerOutput
,
Future
[
ModelRunnerOutput
]]:
(
output
,
)
=
self
.
collective_rpc
(
non_block
=
self
.
max_concurrent_batches
>
1
if
not
self
.
has_connector
:
# get output only from a single worker (output_rank)
(
output
,
)
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
),
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
)
return
output
# get output from all workers
outputs
=
self
.
collective_rpc
(
"execute_model"
,
"execute_model"
,
args
=
(
scheduler_output
,
),
args
=
(
scheduler_output
,
),
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
non_block
=
self
.
max_concurrent_batches
>
1
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
)
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
)
return
output
# aggregate all workers output to a single output
if
non_block
:
return
self
.
_async_aggregate_workers_output
(
outputs
)
return
self
.
_aggregate_workers_output
(
outputs
)
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
],
...
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
...
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
except
TimeoutError
as
e
:
except
TimeoutError
as
e
:
raise
TimeoutError
(
f
"RPC call to
{
method
}
timed out."
)
from
e
raise
TimeoutError
(
f
"RPC call to
{
method
}
timed out."
)
from
e
def
_aggregate_workers_output
(
self
,
outputs
:
list
[
ModelRunnerOutput
])
->
ModelRunnerOutput
:
# aggregate finished_sending, finished_recving from all workers
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
for
output
in
outputs
:
# update finished_sending
for
req_id
in
output
.
finished_sending
or
[]:
new_count
=
self
.
_send_remaining_count
[
req_id
]
-
1
if
new_count
==
0
:
# got response from all workers, report back to scheduler
finished_sending
.
add
(
req_id
)
del
self
.
_send_remaining_count
[
req_id
]
else
:
self
.
_send_remaining_count
[
req_id
]
=
new_count
# update finished_recving
for
req_id
in
output
.
finished_recving
or
[]:
new_count
=
self
.
_recv_remaining_count
[
req_id
]
-
1
if
new_count
==
0
:
# got response from all workers, report back to scheduler
finished_recving
.
add
(
req_id
)
del
self
.
_recv_remaining_count
[
req_id
]
else
:
self
.
_recv_remaining_count
[
req_id
]
=
new_count
# select output of the worker specified by output_rank
output
=
outputs
[
self
.
output_rank
]
# set the aggregated finished_sending / finished_recving
if
finished_sending
:
output
.
finished_sending
=
finished_sending
if
finished_recving
:
output
.
finished_recving
=
finished_recving
return
output
def
_async_aggregate_workers_output
(
self
,
output_futures
:
list
[
Future
[
ModelRunnerOutput
]]
)
->
(
Future
[
ModelRunnerOutput
]):
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future
:
Future
[
ModelRunnerOutput
]
=
Future
()
outputs
:
list
[
Optional
[
ModelRunnerOutput
]]
=
[
None
]
*
len
(
output_futures
)
def
make_callback
(
idx
):
def
callback
(
fut
):
if
result_future
.
done
():
return
try
:
outputs
[
idx
]
=
fut
.
result
()
except
CancelledError
:
result_future
.
cancel
()
except
Exception
as
e
:
result_future
.
set_exception
(
e
)
# this check assumes io_thread_pool uses a single thread
if
all
(
outputs
):
result_future
.
set_result
(
self
.
_aggregate_workers_output
(
cast
(
list
[
ModelRunnerOutput
],
outputs
)))
return
callback
for
i
,
output_future
in
enumerate
(
output_futures
):
output_future
.
add_done_callback
(
make_callback
(
i
))
return
result_future
@
staticmethod
@
staticmethod
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
"""Ensure that all worker processes are terminated. Assumes workers have
"""Ensure that all worker processes are terminated. Assumes workers have
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cc876d0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
gc
import
gc
import
time
import
time
import
weakref
import
weakref
...
@@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens_np
:
np
.
ndarray
,
num_scheduled_tokens_np
:
np
.
ndarray
,
finished_sending
:
Optional
[
set
[
str
]],
finished_recving
:
Optional
[
set
[
str
]],
)
->
ModelRunnerOutput
:
)
->
ModelRunnerOutput
:
assert
self
.
input_batch
.
num_reqs
==
\
assert
self
.
input_batch
.
num_reqs
==
\
len
(
self
.
input_batch
.
pooling_params
),
\
len
(
self
.
input_batch
.
pooling_params
),
\
...
@@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
pooler_output
=
pooler_output
,
pooler_output
=
pooler_output
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
)
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
if
has_kv_transfer_group
():
# Return empty ModelRunnerOutput if there's no work to do.
with
set_forward_context
(
None
,
self
.
vllm_config
):
return
EMPTY_MODEL_RUNNER_OUTPUT
self
.
maybe_setup_kv_connector
(
scheduler_output
)
return
self
.
kv_connector_no_forward
(
scheduler_output
)
# Return empty ModelRunnerOutput if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
# Prepare the decoder inputs.
# Prepare the decoder inputs.
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
...
@@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
self
.
maybe_wait_for_kv_save
()
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
hidden_states
,
aux_hidden_states
=
model_output
...
@@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
if
self
.
input_batch
.
pooling_params
:
if
self
.
input_batch
.
pooling_params
:
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
finished_sending
,
num_scheduled_tokens_np
)
finished_recving
)
sample_hidden_states
=
hidden_states
[
logits_indices
]
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
@@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
=
logprobs_lists
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
num_nans_in_logits
=
num_nans_in_logits
,
)
)
...
@@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
draft_token_ids
.
tolist
()
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
# KV send/recv even if no work to do.
with
set_forward_context
(
None
,
self
.
vllm_config
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
if
not
finished_sending
and
not
finished_recving
:
return
EMPTY_MODEL_RUNNER_OUTPUT
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
finished_sending
=
finished_sending
output
.
finished_recving
=
finished_recving
return
output
@
staticmethod
@
staticmethod
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
# Update KVConnector with the KVConnector metadata forward().
# Update KVConnector with the KVConnector metadata forward().
...
@@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
wait_for_save
()
get_kv_transfer_group
().
wait_for_save
()
@
staticmethod
def
get_finished_kv_transfers
(
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
Optional
[
set
[
str
]],
Optional
[
set
[
str
]]]:
if
has_kv_transfer_group
():
return
get_kv_transfer_group
().
get_finished
(
scheduler_output
.
finished_req_ids
)
return
None
,
None
def
propose_ngram_draft_token_ids
(
def
propose_ngram_draft_token_ids
(
self
,
self
,
sampled_token_ids
:
list
[
list
[
int
]],
sampled_token_ids
:
list
[
list
[
int
]],
...
...
vllm/v1/worker/gpu_worker.py
View file @
cc876d0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
"""A GPU worker class."""
import
copy
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
...
@@ -14,7 +15,9 @@ from vllm.config import VllmConfig
...
@@ -14,7 +15,9 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
,
init_distributed_environment
,
set_custom_all_reduce
)
set_custom_all_reduce
)
from
vllm.distributed.kv_transfer
import
ensure_kv_transfer_initialized
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
has_kv_transfer_group
)
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -23,7 +26,7 @@ from vllm.platforms import current_platform
...
@@ -23,7 +26,7 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
...
@@ -316,14 +319,29 @@ class Worker(WorkerBase):
...
@@ -316,14 +319,29 @@ class Worker(WorkerBase):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
intermediate_tensors
)
parallel_config
=
self
.
vllm_config
.
parallel_config
parallel_config
=
self
.
vllm_config
.
parallel_config
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
and
not
get_pp_group
().
is_last_rank
:
and
not
get_pp_group
().
is_last_rank
:
assert
isinstance
(
output
,
IntermediateTensors
)
assert
isinstance
(
output
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
all_gather_group
=
get_tp_group
())
all_gather_group
=
get_tp_group
())
return
None
output
=
EMPTY_MODEL_RUNNER_OUTPUT
assert
isinstance
(
output
,
ModelRunnerOutput
)
assert
isinstance
(
output
,
ModelRunnerOutput
)
if
has_kv_transfer_group
():
finished_sending
,
finished_recving
=
(
get_kv_transfer_group
().
get_finished
(
scheduler_output
.
finished_req_ids
))
if
finished_sending
or
finished_recving
:
if
output
is
EMPTY_MODEL_RUNNER_OUTPUT
:
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
finished_sending
=
finished_sending
output
.
finished_recving
=
finished_recving
# with a connector, the scheduler expects output from all workers
return
output
# return output only from the driver worker
return
output
if
self
.
is_driver_worker
else
None
return
output
if
self
.
is_driver_worker
else
None
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment