Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d524ce7
Unverified
Commit
8d524ce7
authored
Aug 02, 2025
by
Nick Hill
Committed by
GitHub
Aug 01, 2025
Browse files
[BugFix] Improve internal DP load balancing (#21617)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
9f9c38c3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
122 additions
and
59 deletions
+122
-59
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+4
-0
vllm/v1/engine/coordinator.py
vllm/v1/engine/coordinator.py
+73
-37
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+8
-5
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+29
-17
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-0
vllm/v1/utils.py
vllm/v1/utils.py
+1
-0
No files found.
vllm/entrypoints/openai/api_server.py
View file @
8d524ce7
...
@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(
...
@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
async_llm
:
Optional
[
AsyncLLM
]
=
None
async_llm
:
Optional
[
AsyncLLM
]
=
None
client_count
=
client_config
.
pop
(
"client_count"
)
if
client_config
else
1
client_index
=
client_config
.
pop
(
client_index
=
client_config
.
pop
(
"client_index"
)
if
client_config
else
0
"client_index"
)
if
client_config
else
0
try
:
try
:
...
@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
...
@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
enable_log_requests
=
engine_args
.
enable_log_requests
,
enable_log_requests
=
engine_args
.
enable_log_requests
,
disable_log_stats
=
engine_args
.
disable_log_stats
,
disable_log_stats
=
engine_args
.
disable_log_stats
,
client_addresses
=
client_config
,
client_addresses
=
client_config
,
client_count
=
client_count
,
client_index
=
client_index
)
client_index
=
client_index
)
# Don't keep the dummy data in memory
# Don't keep the dummy data in memory
...
...
vllm/v1/engine/async_llm.py
View file @
8d524ce7
...
@@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
...
@@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
,
client_index
:
int
=
0
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -120,6 +121,7 @@ class AsyncLLM(EngineClient):
...
@@ -120,6 +121,7 @@ class AsyncLLM(EngineClient):
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
self
.
log_stats
,
log_stats
=
self
.
log_stats
,
client_addresses
=
client_addresses
,
client_addresses
=
client_addresses
,
client_count
=
client_count
,
client_index
=
client_index
,
client_index
=
client_index
,
)
)
...
@@ -156,6 +158,7 @@ class AsyncLLM(EngineClient):
...
@@ -156,6 +158,7 @@ class AsyncLLM(EngineClient):
enable_log_requests
:
bool
=
False
,
enable_log_requests
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
,
client_index
:
int
=
0
,
disable_log_requests
:
bool
=
True
,
# Deprecated, will be removed
disable_log_requests
:
bool
=
True
,
# Deprecated, will be removed
)
->
"AsyncLLM"
:
)
->
"AsyncLLM"
:
...
@@ -176,6 +179,7 @@ class AsyncLLM(EngineClient):
...
@@ -176,6 +179,7 @@ class AsyncLLM(EngineClient):
log_stats
=
not
disable_log_stats
,
log_stats
=
not
disable_log_stats
,
usage_context
=
usage_context
,
usage_context
=
usage_context
,
client_addresses
=
client_addresses
,
client_addresses
=
client_addresses
,
client_count
=
client_count
,
client_index
=
client_index
,
client_index
=
client_index
,
)
)
...
...
vllm/v1/engine/coordinator.py
View file @
8d524ce7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
multiprocessing
import
multiprocessing
import
time
import
time
import
weakref
import
weakref
...
@@ -65,18 +66,14 @@ class DPCoordinator:
...
@@ -65,18 +66,14 @@ class DPCoordinator:
# Assume coordinator is colocated with front-end procs when not in
# Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode.
# either external or hybrid DP LB mode.
local_only
=
not
(
external_lb
or
hybrid_lb
)
front_publish_address
=
get_engine_client_zmq_addr
(
front_publish_address
=
get_engine_client_zmq_addr
(
local_only
=
not
external_lb
and
not
hybrid_lb
,
host
=
host
)
local_only
=
local_only
,
host
=
host
)
local_only_eng
=
dp_size
==
parallel_config
.
data_parallel_size_local
local_only_eng
=
dp_size
==
parallel_config
.
data_parallel_size_local
back_publish_address
=
get_engine_client_zmq_addr
(
local_only_eng
,
host
)
back_publish_address
=
get_engine_client_zmq_addr
(
local_only_eng
,
host
)
back_output_address
=
get_engine_client_zmq_addr
(
local_only_eng
,
host
)
back_output_address
=
get_engine_client_zmq_addr
(
local_only_eng
,
host
)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms
=
0
if
external_lb
else
100
context
=
get_mp_context
()
context
=
get_mp_context
()
self
.
proc
:
multiprocessing
.
Process
=
context
.
Process
(
self
.
proc
:
multiprocessing
.
Process
=
context
.
Process
(
target
=
DPCoordinatorProc
.
run_coordinator
,
target
=
DPCoordinatorProc
.
run_coordinator
,
...
@@ -86,7 +83,6 @@ class DPCoordinator:
...
@@ -86,7 +83,6 @@ class DPCoordinator:
"front_publish_address"
:
front_publish_address
,
"front_publish_address"
:
front_publish_address
,
"back_output_address"
:
back_output_address
,
"back_output_address"
:
back_output_address
,
"back_publish_address"
:
back_publish_address
,
"back_publish_address"
:
back_publish_address
,
"min_stats_update_interval_ms"
:
min_stats_update_interval_ms
,
},
},
daemon
=
True
)
daemon
=
True
)
self
.
proc
.
start
()
self
.
proc
.
start
()
...
@@ -125,10 +121,6 @@ class DPCoordinatorProc:
...
@@ -125,10 +121,6 @@ class DPCoordinatorProc:
self
.
stats_update_interval_ms
=
min_stats_update_interval_ms
self
.
stats_update_interval_ms
=
min_stats_update_interval_ms
self
.
current_wave
=
0
self
.
engines_running
=
False
self
.
stats_changed
=
False
@
staticmethod
@
staticmethod
def
run_coordinator
(
def
run_coordinator
(
engine_count
:
int
,
engine_count
:
int
,
...
@@ -155,6 +147,16 @@ class DPCoordinatorProc:
...
@@ -155,6 +147,16 @@ class DPCoordinatorProc:
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
# For tracking request wave progression.
current_wave
=
0
engines_running
=
False
# For tracking request counts for internal load-balancing.
stats_changed
=
False
last_stats_step
=
-
1
last_stats_wave
=
-
1
last_step_counts
:
Optional
[
list
[
list
[
int
]]]
=
None
with
make_zmq_socket
(
with
make_zmq_socket
(
path
=
front_publish_address
,
# IPC
path
=
front_publish_address
,
# IPC
ctx
=
self
.
ctx
,
ctx
=
self
.
ctx
,
...
@@ -191,21 +193,33 @@ class DPCoordinatorProc:
...
@@ -191,21 +193,33 @@ class DPCoordinatorProc:
while
True
:
while
True
:
elapsed
=
int
(
time
.
time
()
*
1000
)
-
last_publish_time
elapsed
=
int
(
time
.
time
()
*
1000
)
-
last_publish_time
# Send at stats_update_interval_ms interval if the stats have
# Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every
4
seconds.
# changed, or otherwise every
5
seconds.
wait_for
=
(
self
.
stats_update_interval_ms
wait_for
=
(
self
.
stats_update_interval_ms
if
self
.
stats_changed
else
4000
)
if
stats_changed
else
5000
)
events
=
poller
.
poll
(
timeout
=
max
(
0
,
wait_for
-
elapsed
))
# Wait at least 50ms to ensure we've received all stats for
# the current step.
min_timeout
=
50
if
last_step_counts
is
None
else
0
events
=
poller
.
poll
(
timeout
=
max
(
min_timeout
,
wait_for
-
elapsed
))
if
not
events
:
if
not
events
:
# Poller timeout - publish current stats to front-ends.
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list
=
self
.
_get_engine_counts
()
if
last_step_counts
is
not
None
:
to_publish
=
(
engine_req_counts_list
,
self
.
current_wave
,
engine_req_counts_list
=
last_step_counts
self
.
engines_running
)
last_step_counts
=
None
else
:
engine_req_counts_list
=
self
.
_get_engine_counts
()
stats_changed
=
False
to_publish
=
(
engine_req_counts_list
,
current_wave
,
engines_running
)
publish_front
.
send
(
msgspec
.
msgpack
.
encode
(
to_publish
))
publish_front
.
send
(
msgspec
.
msgpack
.
encode
(
to_publish
))
last_publish_time
=
int
(
time
.
time
()
*
1000
)
last_publish_time
=
int
(
time
.
time
()
*
1000
)
self
.
stats_changed
=
False
continue
continue
events
=
dict
(
events
)
events
=
dict
(
events
)
wave_state_changed
=
False
if
publish_front
in
events
:
if
publish_front
in
events
:
buffer
=
publish_front
.
recv
()
buffer
=
publish_front
.
recv
()
...
@@ -232,7 +246,7 @@ class DPCoordinatorProc:
...
@@ -232,7 +246,7 @@ class DPCoordinatorProc:
# current_wave
# current_wave
# we note that 0 is the wave number for the new
# we note that 0 is the wave number for the new
# engine
# engine
self
.
engines_running
=
False
engines_running
=
False
logger
.
info
(
logger
.
info
(
"DPCoordinator scaled up from %s to %s "
"DPCoordinator scaled up from %s to %s "
"engines"
,
current_count
,
new_engine_count
)
"engines"
,
current_count
,
new_engine_count
)
...
@@ -248,15 +262,15 @@ class DPCoordinatorProc:
...
@@ -248,15 +262,15 @@ class DPCoordinatorProc:
# engines are paused, so that we can wake the other
# engines are paused, so that we can wake the other
# engines.
# engines.
engine_to_exclude
,
wave
=
decoded
engine_to_exclude
,
wave
=
decoded
if
not
self
.
engines_running
:
if
not
engines_running
:
if
wave
<
self
.
current_wave
:
if
wave
<
current_wave
:
# If the wave number is stale, ensure the message
# If the wave number is stale, ensure the message
# is handled by all the engines.
# is handled by all the engines.
engine_to_exclude
=
None
engine_to_exclude
=
None
self
.
engines_running
=
True
engines_running
=
True
self
.
stat
s
_changed
=
True
wave_
stat
e
_changed
=
True
self
.
_send_start_wave
(
publish_back
,
self
.
current_wave
,
self
.
_send_start_wave
(
publish_back
,
current_wave
,
engine_to_exclude
)
engine_to_exclude
)
if
output_back
in
events
:
if
output_back
in
events
:
...
@@ -274,36 +288,56 @@ class DPCoordinatorProc:
...
@@ -274,36 +288,56 @@ class DPCoordinatorProc:
# 1. Updated request load stats - update our local
# 1. Updated request load stats - update our local
# state with these.
# state with these.
stats
=
self
.
engines
[
eng_index
].
request_counts
stats
=
self
.
engines
[
eng_index
].
request_counts
stats_step
=
scheduler_stats
.
step_counter
stats_wave
=
scheduler_stats
.
current_wave
if
(
stats_wave
>
last_stats_wave
or
stats_wave
==
last_stats_wave
and
stats_step
>
last_stats_step
):
if
stats_changed
:
last_step_counts
=
self
.
_get_engine_counts
(
do_copy
=
True
)
last_stats_step
=
stats_step
last_stats_wave
=
stats_wave
elif
stats_wave
!=
last_stats_wave
or
(
stats_step
!=
last_stats_step
):
logger
.
warning
(
"Received stats for out-of-order "
"step (%d, %d) from engine %d (expected "
"> (%d, %d))"
,
stats_wave
,
stats_step
,
eng_index
,
last_stats_wave
,
last_stats_step
)
stats
[
0
]
=
scheduler_stats
.
num_waiting_reqs
stats
[
0
]
=
scheduler_stats
.
num_waiting_reqs
stats
[
1
]
=
scheduler_stats
.
num_running_reqs
stats
[
1
]
=
scheduler_stats
.
num_running_reqs
self
.
stats_changed
=
True
stats_changed
=
True
if
(
wave
:
=
outputs
.
wave_complete
)
is
not
None
:
if
(
wave
:
=
outputs
.
wave_complete
)
is
not
None
:
# 2. Notification from rank 0 engine that we've
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# moved into the global paused state
# (engines_running==False).
# (engines_running==False).
if
self
.
current_wave
<=
wave
:
if
current_wave
<=
wave
:
new_wave
=
wave
+
1
new_wave
=
wave
+
1
logger
.
debug
(
"Moving DP wave from %d to %d."
,
logger
.
debug
(
"Moving DP wave from %d to %d."
,
self
.
current_wave
,
new_wave
)
current_wave
,
new_wave
)
self
.
current_wave
=
new_wave
current_wave
=
new_wave
self
.
engines_running
=
False
engines_running
=
False
self
.
stat
s
_changed
=
True
wave_
stat
e
_changed
=
True
elif
(
wave
:
=
outputs
.
start_wave
)
is
not
None
and
(
elif
(
wave
:
=
outputs
.
start_wave
)
is
not
None
and
(
wave
>
self
.
current_wave
or
wave
>
current_wave
or
(
wave
==
self
.
current_wave
(
wave
==
current_wave
and
not
engines_running
)):
and
not
self
.
engines_running
)):
# 3. The engine received request for a non-current wave
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# so we must ensure that other engines progress to the
# next wave (race condition handling).
# next wave (race condition handling).
logger
.
debug
(
logger
.
debug
(
"Starting wave %d after notification of "
"Starting wave %d after notification of "
"stale wave request from engine."
,
wave
)
"stale wave request from engine."
,
wave
)
self
.
current_wave
=
wave
current_wave
=
wave
self
.
engines_running
=
True
engines_running
=
True
self
.
stat
s
_changed
=
True
wave_
stat
e
_changed
=
True
self
.
_send_start_wave
(
publish_back
,
wave
,
eng_index
)
self
.
_send_start_wave
(
publish_back
,
wave
,
eng_index
)
if
wave_state_changed
:
message
=
(
None
,
current_wave
,
engines_running
)
publish_front
.
send
(
msgspec
.
msgpack
.
encode
(
message
))
@
staticmethod
@
staticmethod
def
_send_start_wave
(
socket
:
zmq
.
Socket
,
wave
:
int
,
def
_send_start_wave
(
socket
:
zmq
.
Socket
,
wave
:
int
,
exclude_engine_index
:
Optional
[
int
]):
exclude_engine_index
:
Optional
[
int
]):
...
@@ -316,6 +350,8 @@ class DPCoordinatorProc:
...
@@ -316,6 +350,8 @@ class DPCoordinatorProc:
socket
.
send_multipart
(
socket
.
send_multipart
(
(
EngineCoreRequestType
.
START_DP_WAVE
.
value
,
wave_encoded
))
(
EngineCoreRequestType
.
START_DP_WAVE
.
value
,
wave_encoded
))
def
_get_engine_counts
(
self
)
->
list
[
list
[
int
]]:
def
_get_engine_counts
(
self
,
do_copy
=
False
)
->
list
[
list
[
int
]]:
"""Return list of [waiting, running] count lists for each engine."""
"""Return list of [waiting, running] count lists for each engine."""
if
do_copy
:
return
[
copy
.
copy
(
e
.
request_counts
)
for
e
in
self
.
engines
]
return
[
e
.
request_counts
for
e
in
self
.
engines
]
return
[
e
.
request_counts
for
e
in
self
.
engines
]
vllm/v1/engine/core.py
View file @
8d524ce7
...
@@ -928,7 +928,7 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -928,7 +928,7 @@ class DPEngineCoreProc(EngineCoreProc):
):
):
# Counts forward-passes of the model so that we can synchronize
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
# finished with DP peers every N steps.
self
.
counter
=
0
self
.
step_
counter
=
0
self
.
current_wave
=
0
self
.
current_wave
=
0
self
.
last_counts
=
(
0
,
0
)
self
.
last_counts
=
(
0
,
0
)
...
@@ -999,7 +999,9 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -999,7 +999,9 @@ class DPEngineCoreProc(EngineCoreProc):
counts
=
self
.
scheduler
.
get_request_counts
()
counts
=
self
.
scheduler
.
get_request_counts
()
if
counts
!=
self
.
last_counts
:
if
counts
!=
self
.
last_counts
:
self
.
last_counts
=
counts
self
.
last_counts
=
counts
stats
=
SchedulerStats
(
*
counts
)
stats
=
SchedulerStats
(
*
counts
,
step_counter
=
self
.
step_counter
,
current_wave
=
self
.
current_wave
)
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
(
-
1
,
EngineCoreOutputs
(
scheduler_stats
=
stats
)))
(
-
1
,
EngineCoreOutputs
(
scheduler_stats
=
stats
)))
...
@@ -1041,15 +1043,16 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -1041,15 +1043,16 @@ class DPEngineCoreProc(EngineCoreProc):
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
(
client_index
,
(
client_index
,
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
# Increment wave count and reset step counter.
self
.
current_wave
+=
1
self
.
current_wave
+=
1
self
.
step_counter
=
0
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
def
_has_global_unfinished_reqs
(
self
,
local_unfinished
:
bool
)
->
bool
:
# Optimization - only perform finish-sync all-reduce every 32 steps.
# Optimization - only perform finish-sync all-reduce every 32 steps.
self
.
counter
+=
1
self
.
step_
counter
+=
1
if
self
.
counter
!=
32
:
if
self
.
step_
counter
%
32
!=
0
:
return
True
return
True
self
.
counter
=
0
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
return
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
local_unfinished
)
local_unfinished
)
...
...
vllm/v1/engine/core_client.py
View file @
8d524ce7
...
@@ -86,11 +86,12 @@ class EngineCoreClient(ABC):
...
@@ -86,11 +86,12 @@ class EngineCoreClient(ABC):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
,
client_index
:
int
=
0
,
)
->
"MPClient"
:
)
->
"MPClient"
:
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
client_args
=
(
vllm_config
,
executor_class
,
log_stats
,
client_args
=
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
client_addresses
,
client_count
,
client_index
)
if
parallel_config
.
data_parallel_size
>
1
:
if
parallel_config
.
data_parallel_size
>
1
:
if
parallel_config
.
data_parallel_external_lb
:
if
parallel_config
.
data_parallel_external_lb
:
# External load balancer - client per DP rank.
# External load balancer - client per DP rank.
...
@@ -727,6 +728,7 @@ class AsyncMPClient(MPClient):
...
@@ -727,6 +728,7 @@ class AsyncMPClient(MPClient):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
):
client_index
:
int
=
0
):
super
().
__init__
(
super
().
__init__
(
asyncio_mode
=
True
,
asyncio_mode
=
True
,
...
@@ -929,11 +931,12 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -929,11 +931,12 @@ class DPAsyncMPClient(AsyncMPClient):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
):
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
current_wave
=
0
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
client_addresses
,
client_count
,
client_index
)
# List of [waiting, running] pair per engine.
# List of [waiting, running] pair per engine.
# Used only by DPLBAsyncMPClient subclass.
# Used only by DPLBAsyncMPClient subclass.
...
@@ -1029,7 +1032,11 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -1029,7 +1032,11 @@ class DPAsyncMPClient(AsyncMPClient):
counts
,
wave
,
running
=
msgspec
.
msgpack
.
decode
(
buf
)
counts
,
wave
,
running
=
msgspec
.
msgpack
.
decode
(
buf
)
self
.
current_wave
=
wave
self
.
current_wave
=
wave
self
.
engines_running
=
running
self
.
engines_running
=
running
self
.
lb_engines
=
counts
[
count_slice
]
if
counts
is
not
None
:
sliced_counts
=
counts
[
count_slice
]
self
.
lb_engines
=
sliced_counts
logger
.
debug
(
"Received counts: %s (%s)"
,
sliced_counts
,
count_slice
)
resources
.
stats_update_task
=
asyncio
.
create_task
(
resources
.
stats_update_task
=
asyncio
.
create_task
(
run_engine_stats_update_task
())
run_engine_stats_update_task
())
...
@@ -1065,40 +1072,45 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
...
@@ -1065,40 +1072,45 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
executor_class
:
type
[
Executor
],
executor_class
:
type
[
Executor
],
log_stats
:
bool
,
log_stats
:
bool
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_count
:
int
=
1
,
client_index
:
int
=
0
):
client_index
:
int
=
0
):
self
.
client_count
=
client_count
# To route aborts to the correct engine.
# To route aborts to the correct engine.
self
.
reqs_in_flight
:
dict
[
str
,
EngineIdentity
]
=
{}
self
.
reqs_in_flight
:
dict
[
str
,
EngineIdentity
]
=
{}
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
,
client_addresses
,
client_index
)
client_addresses
,
client_count
,
client_index
)
assert
len
(
self
.
core_engines
)
>
1
assert
len
(
self
.
core_engines
)
>
1
self
.
eng_start_index
=
(
len
(
self
.
core_engines
)
*
self
.
client_index
)
//
client_count
def
get_core_engine_for_request
(
def
get_core_engine_for_request
(
self
,
request
:
EngineCoreRequest
)
->
EngineIdentity
:
self
,
request
:
EngineCoreRequest
)
->
EngineIdentity
:
# Engines are in rank order.
# Engines are in rank order.
current_counts
=
self
.
lb_engines
if
(
eng_index
:
=
request
.
data_parallel_rank
)
is
None
:
if
(
eng_index
:
=
request
.
data_parallel_rank
)
is
None
:
if
not
self
.
lb_engine
s
:
if
not
current_count
s
:
return
self
.
core_engine
return
self
.
core_engine
# TODO use P2C alg for larger DP sizes
# TODO use P2C alg for larger DP sizes
num_engines
=
len
(
self
.
lb_engine
s
)
num_engines
=
len
(
current_count
s
)
min_co
unts
=
[
sys
.
maxsize
,
sys
.
maxsize
]
min_
s
co
re
=
sys
.
maxsize
eng_index
=
0
eng_index
=
0
for
i
in
range
(
num_engines
):
for
i
in
range
(
num_engines
):
# Start from client_index to help with balancing when engines
# Start from client_index to help with balancing when engines
# are empty.
# are empty.
idx
=
(
self
.
client_index
+
i
)
%
num_engines
idx
=
(
self
.
eng_start_index
+
i
)
%
num_engines
counts
=
self
.
lb_engines
[
idx
]
waiting
,
running
=
current_counts
[
idx
]
if
counts
<
min_counts
:
score
=
waiting
*
4
+
running
min_counts
=
counts
if
score
<
min_score
:
min_score
=
score
eng_index
=
idx
eng_index
=
idx
# Adjust local counts for better balancing between stats updates
# Increment local waiting count for better balancing between stats
# from the coordinator (which happen every 100ms).
# updates from the coordinator (which happen every 100ms).
if
min_counts
[
0
]:
current_counts
[
eng_index
][
0
]
+=
self
.
client_count
min_counts
[
0
]
+=
1
else
:
min_counts
[
1
]
+=
1
chosen_engine
=
self
.
core_engines
[
eng_index
]
chosen_engine
=
self
.
core_engines
[
eng_index
]
# Record which engine is chosen for this request, to handle aborts.
# Record which engine is chosen for this request, to handle aborts.
...
...
vllm/v1/metrics/stats.py
View file @
8d524ce7
...
@@ -33,6 +33,10 @@ class SchedulerStats:
...
@@ -33,6 +33,10 @@ class SchedulerStats:
num_running_reqs
:
int
=
0
num_running_reqs
:
int
=
0
num_waiting_reqs
:
int
=
0
num_waiting_reqs
:
int
=
0
# These are used for internal DP load-balancing.
step_counter
:
int
=
0
current_wave
:
int
=
0
kv_cache_usage
:
float
=
0.0
kv_cache_usage
:
float
=
0.0
prefix_cache_stats
:
PrefixCacheStats
=
field
(
prefix_cache_stats
:
PrefixCacheStats
=
field
(
...
...
vllm/v1/utils.py
View file @
8d524ce7
...
@@ -154,6 +154,7 @@ class APIServerProcessManager:
...
@@ -154,6 +154,7 @@ class APIServerProcessManager:
client_config
=
{
client_config
=
{
"input_address"
:
in_addr
,
"input_address"
:
in_addr
,
"output_address"
:
out_addr
,
"output_address"
:
out_addr
,
"client_count"
:
num_servers
,
"client_index"
:
i
"client_index"
:
i
}
}
if
stats_update_address
is
not
None
:
if
stats_update_address
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment