Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8619e715
Unverified
Commit
8619e715
authored
Jun 24, 2025
by
Nick Hill
Committed by
GitHub
Jun 24, 2025
Browse files
[BugFix] Fix multi-node offline data parallel (#19937)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
c635c5f7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
4 deletions
+31
-4
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+6
-2
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+19
-1
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
8619e715
...
@@ -615,13 +615,16 @@ steps:
...
@@ -615,13 +615,16 @@ steps:
-
vllm/executor/
-
vllm/executor/
-
vllm/model_executor/models/
-
vllm/model_executor/models/
-
tests/distributed/
-
tests/distributed/
-
tests/examples/offline_inference/data_parallel.py
commands
:
commands
:
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
# the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
# the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
-
python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
-
label
:
Distributed Tests (2 GPUs)
# 40min
-
label
:
Distributed Tests (2 GPUs)
# 40min
mirror_hardwares
:
[
amdexperimental
]
mirror_hardwares
:
[
amdexperimental
]
...
...
vllm/entrypoints/llm.py
View file @
8619e715
...
@@ -1568,6 +1568,8 @@ class LLM:
...
@@ -1568,6 +1568,8 @@ class LLM:
pbar
.
update
(
n
)
pbar
.
update
(
n
)
else
:
else
:
pbar
.
update
(
1
)
pbar
.
update
(
1
)
if
pbar
.
n
==
num_requests
:
pbar
.
refresh
()
if
use_tqdm
:
if
use_tqdm
:
pbar
.
close
()
pbar
.
close
()
...
...
vllm/v1/engine/core.py
View file @
8619e715
...
@@ -877,12 +877,16 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -877,12 +877,16 @@ class DPEngineCoreProc(EngineCoreProc):
local_unfinished_reqs
)
local_unfinished_reqs
)
if
not
self
.
engines_running
:
if
not
self
.
engines_running
:
if
self
.
dp_rank
==
0
:
if
self
.
dp_rank
==
0
or
not
self
.
has_coordinator
:
# Notify client that we are pausing the loop.
# Notify client that we are pausing the loop.
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
logger
.
debug
(
"Wave %d finished, pausing engine loop."
,
self
.
current_wave
)
self
.
current_wave
)
# In the coordinator case, dp rank 0 sends updates to the
# coordinator. Otherwise (offline spmd case), each rank
# sends the update to its colocated front-end process.
client_index
=
-
1
if
self
.
has_coordinator
else
0
self
.
output_queue
.
put_nowait
(
self
.
output_queue
.
put_nowait
(
(
-
1
,
(
client_index
,
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
EngineCoreOutputs
(
wave_complete
=
self
.
current_wave
)))
self
.
current_wave
+=
1
self
.
current_wave
+=
1
...
...
vllm/v1/engine/core_client.py
View file @
8619e715
...
@@ -155,6 +155,11 @@ class EngineCoreClient(ABC):
...
@@ -155,6 +155,11 @@ class EngineCoreClient(ABC):
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
raise
NotImplementedError
raise
NotImplementedError
def
dp_engines_running
(
self
)
->
bool
:
"""Returns True if data parallel engines are collectively in a
running state."""
raise
NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -282,6 +287,9 @@ class InprocClient(EngineCoreClient):
...
@@ -282,6 +287,9 @@ class InprocClient(EngineCoreClient):
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
_R
]:
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
return
self
.
engine_core
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
dp_engines_running
(
self
)
->
bool
:
return
False
@
dataclass
@
dataclass
class
BackgroundResources
:
class
BackgroundResources
:
...
@@ -384,6 +392,9 @@ class MPClient(EngineCoreClient):
...
@@ -384,6 +392,9 @@ class MPClient(EngineCoreClient):
dp_size
=
parallel_config
.
data_parallel_size
dp_size
=
parallel_config
.
data_parallel_size
dp_rank
=
parallel_config
.
data_parallel_rank
dp_rank
=
parallel_config
.
data_parallel_rank
# State used for data parallel.
self
.
engines_running
=
False
# SPMD mode is where there is an LLM instance per DP rank and
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
# examples/offline_inference/data_parallel.py.
...
@@ -539,6 +550,9 @@ class MPClient(EngineCoreClient):
...
@@ -539,6 +550,9 @@ class MPClient(EngineCoreClient):
while
self
.
pending_messages
and
self
.
pending_messages
[
-
1
][
0
].
done
:
while
self
.
pending_messages
and
self
.
pending_messages
[
-
1
][
0
].
done
:
self
.
pending_messages
.
pop
()
self
.
pending_messages
.
pop
()
def
dp_engines_running
(
self
)
->
bool
:
return
self
.
engines_running
def
_process_utility_output
(
output
:
UtilityOutput
,
def
_process_utility_output
(
output
:
UtilityOutput
,
utility_results
:
dict
[
int
,
AnyFuture
]):
utility_results
:
dict
[
int
,
AnyFuture
]):
...
@@ -562,6 +576,7 @@ class SyncMPClient(MPClient):
...
@@ -562,6 +576,7 @@ class SyncMPClient(MPClient):
log_stats
=
log_stats
,
log_stats
=
log_stats
,
)
)
self
.
is_dp
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
self
.
outputs_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
Exception
]]()
self
.
outputs_queue
=
queue
.
Queue
[
Union
[
EngineCoreOutputs
,
Exception
]]()
# Ensure that the outputs socket processing thread does not have
# Ensure that the outputs socket processing thread does not have
...
@@ -623,6 +638,8 @@ class SyncMPClient(MPClient):
...
@@ -623,6 +638,8 @@ class SyncMPClient(MPClient):
outputs
=
self
.
outputs_queue
.
get
()
outputs
=
self
.
outputs_queue
.
get
()
if
isinstance
(
outputs
,
Exception
):
if
isinstance
(
outputs
,
Exception
):
raise
self
.
_format_exception
(
outputs
)
from
None
raise
self
.
_format_exception
(
outputs
)
from
None
if
outputs
.
wave_complete
is
not
None
:
self
.
engines_running
=
False
return
outputs
return
outputs
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
...
@@ -650,6 +667,8 @@ class SyncMPClient(MPClient):
...
@@ -650,6 +667,8 @@ class SyncMPClient(MPClient):
return
future
.
result
()
return
future
.
result
()
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
if
self
.
is_dp
:
self
.
engines_running
=
True
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
self
.
_send_input
(
EngineCoreRequestType
.
ADD
,
request
)
def
abort_requests
(
self
,
request_ids
:
list
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
list
[
str
])
->
None
:
...
@@ -911,7 +930,6 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -911,7 +930,6 @@ class DPAsyncMPClient(AsyncMPClient):
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
):
client_index
:
int
=
0
):
self
.
current_wave
=
0
self
.
current_wave
=
0
self
.
engines_running
=
False
# To route aborts to the correct engine.
# To route aborts to the correct engine.
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
...
...
vllm/v1/engine/llm_engine.py
View file @
8619e715
...
@@ -160,7 +160,7 @@ class LLMEngine:
...
@@ -160,7 +160,7 @@ class LLMEngine:
def
has_unfinished_requests
(
self
)
->
bool
:
def
has_unfinished_requests
(
self
)
->
bool
:
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
if
self
.
dp_group
is
None
:
if
self
.
dp_group
is
None
:
return
has_unfinished
return
has_unfinished
or
self
.
engine_core
.
dp_engines_running
()
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
def
has_unfinished_requests_dp
(
self
,
has_unfinished
:
bool
)
->
bool
:
def
has_unfinished_requests_dp
(
self
,
has_unfinished
:
bool
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment