Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc876d0f
Unverified
Commit
cc876d0f
authored
Jul 10, 2025
by
Or Ozeri
Committed by
GitHub
Jul 10, 2025
Browse files
[KVConnector] Aggregate finished requests on the scheduler (#19555)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
fdfd409f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
139 additions
and
110 deletions
+139
-110
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+3
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+4
-61
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+105
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-40
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+21
-3
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
cc876d0f
...
...
@@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
)
->
tuple
[
Optional
[
set
[
str
]],
Optional
[
set
[
str
]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
cc876d0f
...
...
@@ -408,14 +408,6 @@ class NixlConnectorWorker:
# Track the expiration time of requests that are waiting to be sent.
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self
.
_done_recving_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
self
.
_done_sending_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
# Background thread for initializing new NIXL handshakes.
...
...
@@ -830,15 +822,9 @@ class NixlConnectorWorker:
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
"""
Get requests that are done sending or recving.
In TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
...
...
@@ -858,50 +844,7 @@ class NixlConnectorWorker:
del
self
.
_reqs_to_send
[
req_id
]
done_sending
.
add
(
req_id
)
if
self
.
world_size
==
1
:
return
done_sending
,
done_recving
# Rank 0: get finished from all other ranks.
if
self
.
tp_rank
==
0
:
for
req_id
in
done_sending
:
self
.
_done_sending_count
[
req_id
]
+=
1
for
req_id
in
done_recving
:
self
.
_done_recving_count
[
req_id
]
+=
1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids
:
list
[
str
]
=
[]
for
i
in
range
(
1
,
self
.
world_size
):
other_ranks_finished_ids
.
extend
(
self
.
tp_group
.
recv_object
(
src
=
i
))
for
req_id
in
other_ranks_finished_ids
:
if
(
req_id
in
self
.
_done_recving_count
or
req_id
in
self
.
_recving_transfers
):
self
.
_done_recving_count
[
req_id
]
+=
1
else
:
self
.
_done_sending_count
[
req_id
]
+=
1
# Return ids that finished on all ranks to the scheduler.
all_done_recving
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_recving_count
.
keys
()):
if
self
.
_done_recving_count
[
req_id
]
==
self
.
world_size
:
del
self
.
_done_recving_count
[
req_id
]
all_done_recving
.
add
(
req_id
)
all_done_sending
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_sending_count
.
keys
()):
if
self
.
_done_sending_count
[
req_id
]
>=
self
.
world_size
:
del
self
.
_done_sending_count
[
req_id
]
all_done_sending
.
add
(
req_id
)
return
all_done_sending
,
all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else
:
finished_req_ids
=
list
(
done_recving
.
union
(
done_sending
))
self
.
tp_group
.
send_object
(
finished_req_ids
,
dst
=
0
)
# Unused as only Rank 0 results are sent to scheduler.
return
done_sending
,
done_recving
return
done_sending
,
done_recving
def
_get_new_notifs
(
self
)
->
set
[
str
]:
"""
...
...
vllm/v1/executor/multiproc_executor.py
View file @
cc876d0f
...
...
@@ -9,7 +9,8 @@ import threading
import
time
import
traceback
import
weakref
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
collections
import
defaultdict
from
concurrent.futures
import
CancelledError
,
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
functools
import
partial
...
...
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
if
self
.
max_concurrent_batches
>
1
:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self
.
io_thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"mp_exec_io"
)
self
.
output_rank
=
self
.
_get_output_rank
()
self
.
has_connector
=
self
.
vllm_config
.
kv_transfer_config
is
not
None
# Complete transfer tracker. Used to track finished requests.
# [req_id -> number of workers that have not yet reported the request]
self
.
_recv_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
self
.
world_size
)
self
.
_send_remaining_count
=
defaultdict
[
str
,
int
](
lambda
:
self
.
world_size
)
def
start_worker_monitor
(
self
):
workers
=
self
.
workers
...
...
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
self
,
scheduler_output
,
)
->
Union
[
ModelRunnerOutput
,
Future
[
ModelRunnerOutput
]]:
(
output
,
)
=
self
.
collective_rpc
(
non_block
=
self
.
max_concurrent_batches
>
1
if
not
self
.
has_connector
:
# get output only from a single worker (output_rank)
(
output
,
)
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
),
unique_reply_rank
=
self
.
output_rank
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
)
return
output
# get output from all workers
outputs
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
),
unique_reply_rank
=
self
.
output_rank
,
non_block
=
self
.
max_concurrent_batches
>
1
,
non_block
=
non_block
,
timeout
=
envs
.
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
)
return
output
# aggregate all workers output to a single output
if
non_block
:
return
self
.
_async_aggregate_workers_output
(
outputs
)
return
self
.
_aggregate_workers_output
(
outputs
)
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
...
...
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
except
TimeoutError
as
e
:
raise
TimeoutError
(
f
"RPC call to
{
method
}
timed out."
)
from
e
def _aggregate_workers_output(
        self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
    """Merge the per-worker outputs of one step into a single output.

    Every worker reports the request ids whose KV transfer finished on that
    worker. A request is reported back to the scheduler only once all
    workers have reported it; until then the number of outstanding workers
    is kept in ``self._send_remaining_count`` /
    ``self._recv_remaining_count`` (default dicts that start every request
    at ``self.world_size``).

    Args:
        outputs: one output per worker, indexable by ``self.output_rank``.

    Returns:
        The output of the worker at ``self.output_rank``, modified in place
        so that ``finished_sending`` / ``finished_recving`` hold the ids
        that finished on all workers, or ``None`` when there are none.
    """
    finished_sending: set[str] = set()
    finished_recving: set[str] = set()

    def tally(req_ids, remaining, done):
        # Count one more worker as finished for each reported request.
        for req_id in req_ids or ():
            new_count = remaining[req_id] - 1
            if new_count == 0:
                # got response from all workers, report back to scheduler
                done.add(req_id)
                del remaining[req_id]
            else:
                remaining[req_id] = new_count

    for worker_output in outputs:
        tally(worker_output.finished_sending, self._send_remaining_count,
              finished_sending)
        tally(worker_output.finished_recving, self._recv_remaining_count,
              finished_recving)

    # select output of the worker specified by output_rank
    output = outputs[self.output_rank]

    # Always overwrite the worker-local sets. If the selected worker has
    # finished a request that other workers are still transferring, leaving
    # its own set in place would report the request to the scheduler too
    # early, so fall back to None until every worker has reported it.
    output.finished_sending = finished_sending or None
    output.finished_recving = finished_recving or None
    return output
def
_async_aggregate_workers_output
(
self
,
output_futures
:
list
[
Future
[
ModelRunnerOutput
]]
)
->
(
Future
[
ModelRunnerOutput
]):
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future
:
Future
[
ModelRunnerOutput
]
=
Future
()
outputs
:
list
[
Optional
[
ModelRunnerOutput
]]
=
[
None
]
*
len
(
output_futures
)
def
make_callback
(
idx
):
def
callback
(
fut
):
if
result_future
.
done
():
return
try
:
outputs
[
idx
]
=
fut
.
result
()
except
CancelledError
:
result_future
.
cancel
()
except
Exception
as
e
:
result_future
.
set_exception
(
e
)
# this check assumes io_thread_pool uses a single thread
if
all
(
outputs
):
result_future
.
set_result
(
self
.
_aggregate_workers_output
(
cast
(
list
[
ModelRunnerOutput
],
outputs
)))
return
callback
for
i
,
output_future
in
enumerate
(
output_futures
):
output_future
.
add_done_callback
(
make_callback
(
i
))
return
result_future
@
staticmethod
def
_ensure_worker_termination
(
worker_procs
:
list
[
BaseProcess
]):
"""Ensure that all worker processes are terminated. Assumes workers have
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cc876d0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
gc
import
time
import
weakref
...
...
@@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
:
torch
.
Tensor
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens_np
:
np
.
ndarray
,
finished_sending
:
Optional
[
set
[
str
]],
finished_recving
:
Optional
[
set
[
str
]],
)
->
ModelRunnerOutput
:
assert
self
.
input_batch
.
num_reqs
==
\
len
(
self
.
input_batch
.
pooling_params
),
\
...
...
@@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
pooler_output
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
)
@
torch
.
inference_mode
()
...
...
@@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
# Return empty ModelRunnerOutput if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
has_kv_transfer_group
():
with
set_forward_context
(
None
,
self
.
vllm_config
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
return
self
.
kv_connector_no_forward
(
scheduler_output
)
# Return empty ModelRunnerOutput if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
# Prepare the decoder inputs.
(
attn_metadata
,
attention_cuda_graphs
,
logits_indices
,
...
...
@@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
...
...
@@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
if
self
.
input_batch
.
pooling_params
:
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
finished_sending
,
finished_recving
)
num_scheduled_tokens_np
)
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
...
@@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
num_nans_in_logits
=
num_nans_in_logits
,
)
...
...
@@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
# Sets up the KV connector for this scheduler output and collects the
# finished send/recv request ids, then wraps them in a ModelRunnerOutput.
# NOTE(review): indentation is lost in this rendering, so it cannot be told
# from here which of the following statements sit inside the `with` body.
# NOTE(review): the hunk header (-1686,22 +1677,6) suggests this method is
# removed by the commit - confirm before relying on it.
# KV send/recv even if no work to do.
with
set_forward_context
(
None
,
self
.
vllm_config
):
self
.
maybe_setup_kv_connector
(
scheduler_output
)
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
# Nothing finished: hand back the shared empty-output constant as is.
if
not
finished_sending
and
not
finished_recving
:
return
EMPTY_MODEL_RUNNER_OUTPUT
# Something finished: shallow-copy the constant first so the attribute
# assignments below land on the copy and not on the shared object.
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
finished_sending
=
finished_sending
output
.
finished_recving
=
finished_recving
return
output
@
staticmethod
def
maybe_setup_kv_connector
(
scheduler_output
:
"SchedulerOutput"
):
# Update KVConnector with the KVConnector metadata forward().
...
...
@@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
get_kv_transfer_group
().
wait_for_save
()
@staticmethod
def get_finished_kv_transfers(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Ask the KV transfer group which requests have finished.

    Returns whatever ``get_finished`` of the KV transfer group returns for
    ``scheduler_output.finished_req_ids``, or ``(None, None)`` when no KV
    transfer group exists.
    """
    if not has_kv_transfer_group():
        return None, None
    kv_group = get_kv_transfer_group()
    return kv_group.get_finished(scheduler_output.finished_req_ids)
def
propose_ngram_draft_token_ids
(
self
,
sampled_token_ids
:
list
[
list
[
int
]],
...
...
vllm/v1/worker/gpu_worker.py
View file @
cc876d0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import
copy
import
gc
import
os
from
typing
import
TYPE_CHECKING
,
Optional
...
...
@@ -14,7 +15,9 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
,
set_custom_all_reduce
)
from
vllm.distributed.kv_transfer
import
ensure_kv_transfer_initialized
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_initialized
,
get_kv_transfer_group
,
has_kv_transfer_group
)
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
...
...
@@ -23,7 +26,7 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
from
vllm.v1.utils
import
report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
...
...
@@ -316,14 +319,29 @@ class Worker(WorkerBase):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
parallel_config
=
self
.
vllm_config
.
parallel_config
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
and
not
get_pp_group
().
is_last_rank
:
assert
isinstance
(
output
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
all_gather_group
=
get_tp_group
())
return
None
output
=
EMPTY_MODEL_RUNNER_OUTPUT
assert
isinstance
(
output
,
ModelRunnerOutput
)
if
has_kv_transfer_group
():
finished_sending
,
finished_recving
=
(
get_kv_transfer_group
().
get_finished
(
scheduler_output
.
finished_req_ids
))
if
finished_sending
or
finished_recving
:
if
output
is
EMPTY_MODEL_RUNNER_OUTPUT
:
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
finished_sending
=
finished_sending
output
.
finished_recving
=
finished_recving
# with a connector, the scheduler expects output from all workers
return
output
# return output only from the driver worker
return
output
if
self
.
is_driver_worker
else
None
def
profile
(
self
,
is_start
:
bool
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment