Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e0fd34d
Unverified
Commit
6e0fd34d
authored
May 21, 2025
by
Russell Bryant
Committed by
GitHub
May 21, 2025
Browse files
[CI] Fix race condition with StatelessProcessGroup.barrier (#18506)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
176d62e4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
157 additions
and
25 deletions
+157
-25
tests/distributed/test_shm_broadcast.py
tests/distributed/test_shm_broadcast.py
+5
-5
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+1
-17
vllm/distributed/utils.py
vllm/distributed/utils.py
+151
-3
No files found.
tests/distributed/test_shm_broadcast.py
View file @
6e0fd34d
...
@@ -9,7 +9,7 @@ import torch.distributed as dist
...
@@ -9,7 +9,7 @@ import torch.distributed as dist
from
vllm.distributed.device_communicators.shm_broadcast
import
MessageQueue
from
vllm.distributed.device_communicators.shm_broadcast
import
MessageQueue
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.utils
import
get_ip
,
get_open_port
,
update_environment_variables
from
vllm.utils
import
get_open_port
,
update_environment_variables
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
list
[
np
.
ndarray
]:
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
list
[
np
.
ndarray
]:
...
@@ -60,12 +60,12 @@ def worker_fn():
...
@@ -60,12 +60,12 @@ def worker_fn():
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
if
rank
==
0
:
if
rank
==
0
:
port
=
get_open_port
()
port
=
get_open_port
()
ip
=
get_ip
()
ip
=
'127.0.0.1'
dist
.
broadcast_object_list
([
ip
,
port
],
src
=
0
)
dist
.
broadcast_object_list
([
ip
,
port
],
src
=
0
)
else
:
else
:
recv
=
[
None
,
None
]
recv
=
[
None
,
None
]
dist
.
broadcast_object_list
(
recv
,
src
=
0
)
dist
.
broadcast_object_list
(
recv
,
src
=
0
)
ip
,
port
=
recv
ip
,
port
=
recv
# type: ignore
stateless_pg
=
StatelessProcessGroup
.
create
(
ip
,
port
,
rank
,
stateless_pg
=
StatelessProcessGroup
.
create
(
ip
,
port
,
rank
,
dist
.
get_world_size
())
dist
.
get_world_size
())
...
@@ -107,10 +107,10 @@ def worker_fn():
...
@@ -107,10 +107,10 @@ def worker_fn():
if
pg
==
dist
.
group
.
WORLD
:
if
pg
==
dist
.
group
.
WORLD
:
dist
.
barrier
()
dist
.
barrier
()
print
(
"torch distributed passed the test!"
)
print
(
f
"torch distributed passed the test!
Rank
{
rank
}
"
)
else
:
else
:
pg
.
barrier
()
pg
.
barrier
()
print
(
"StatelessProcessGroup passed the test!"
)
print
(
f
"StatelessProcessGroup passed the test!
Rank
{
rank
}
"
)
def
test_shm_broadcast
():
def
test_shm_broadcast
():
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
6e0fd34d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
pickle
import
pickle
import
sys
import
time
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
...
@@ -19,7 +17,7 @@ from zmq import IPV6 # type: ignore
...
@@ -19,7 +17,7 @@ from zmq import IPV6 # type: ignore
from
zmq
import
SUB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOSE
,
Context
# type: ignore
from
zmq
import
SUB
,
SUBSCRIBE
,
XPUB
,
XPUB_VERBOSE
,
Context
# type: ignore
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
,
sched_yield
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_ip
,
get_open_port
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
get_ip
,
get_open_port
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
)
is_valid_ipv6_address
)
...
@@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
...
@@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0).
# os.sched_yield is also only provided on some platforms, so its presence is
# checked up front instead of failing with AttributeError at call time.
USE_SCHED_YIELD = hasattr(os, "sched_yield") and (
    sys.version_info[:3] >= (3, 11, 1)
    or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8))


def sched_yield():
    """Give up the current time slice inside a polling loop.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true, otherwise
    falls back to ``time.sleep(0)``.
    """
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)
class
ShmRingBuffer
:
class
ShmRingBuffer
:
...
...
vllm/distributed/utils.py
View file @
6e0fd34d
...
@@ -6,9 +6,12 @@
...
@@ -6,9 +6,12 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
dataclasses
import
dataclasses
import
datetime
import
datetime
import
os
import
pickle
import
pickle
import
socket
import
socket
import
sys
import
time
import
time
import
uuid
from
collections
import
deque
from
collections
import
deque
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
...
@@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
...
@@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0).
# os.sched_yield is also only provided on some platforms, so its presence is
# checked up front instead of failing with AttributeError at call time.
USE_SCHED_YIELD = hasattr(os, "sched_yield") and (
    sys.version_info[:3] >= (3, 11, 1)
    or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8))


def sched_yield():
    """Give up the current time slice inside a polling loop.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true, otherwise
    falls back to ``time.sleep(0)``.
    """
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)
def
ensure_divisibility
(
numerator
,
denominator
):
def
ensure_divisibility
(
numerator
,
denominator
):
"""Ensure that numerator is divisible by the denominator."""
"""Ensure that numerator is divisible by the denominator."""
...
@@ -212,10 +229,141 @@ class StatelessProcessGroup:
...
@@ -212,10 +229,141 @@ class StatelessProcessGroup:
gathered_objs
.
append
(
recv_obj
)
gathered_objs
.
append
(
recv_obj
)
return
gathered_objs
return
gathered_objs
def barrier(self, timeout: float = 30.0):
    """A robust barrier to synchronize all ranks.

    Uses a multi-phase approach to ensure all processes reach the barrier
    before proceeding:

    1. Rank 0 generates a unique barrier ID and broadcasts it to all ranks.
    2. Each rank writes an arrival key to the store, then polls until the
       arrival keys of all ranks are present.
    3. Each rank writes a departure key. Non-zero ranks return immediately;
       rank 0 polls until every departure key is present, then deletes all
       arrival and departure keys for this barrier.

    Args:
        timeout: Maximum time, in seconds, to wait for each polling phase.

    Raises:
        RuntimeError: If broadcasting the barrier ID or writing a key fails,
            or if a polling phase exceeds ``timeout``.
    """

    def wait_for_all(prefix: str, barrier_id: str, timeout_msg: str) -> None:
        # Poll the store until every rank's "<prefix>_<barrier_id>_<rank>"
        # key can be read, raising RuntimeError(timeout_msg) on timeout.
        # A monotonic clock keeps the timeout immune to wall-clock jumps.
        started = time.monotonic()
        pending = set(range(self.world_size))
        while pending:
            if time.monotonic() - started > timeout:
                raise RuntimeError(timeout_msg)

            for i in sorted(pending):
                try:
                    # Try to get the key - if it exists, we'll get a value.
                    # If it doesn't exist, it will throw an exception.
                    self.store.get(f"{prefix}_{barrier_id}_{i}")
                except KeyError:
                    # Key doesn't exist yet
                    continue
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()
                else:
                    pending.discard(i)

            # Yield between polling rounds to avoid a tight busy loop
            if pending:
                sched_yield()

    # Generate a barrier ID that is globally unique
    try:
        if self.rank == 0:
            barrier_id = f"barrier_{uuid.uuid4()}"
            self.broadcast_obj(barrier_id, src=0)
        else:
            barrier_id = self.broadcast_obj(None, src=0)
    except Exception as e:
        raise RuntimeError("Failed to broadcast barrier_id") from e

    # Phase 1: Signal arrival at barrier
    # Wait for all processes to arrive
    # We need all ranks to confirm the arrival of all other ranks.
    # This is the key synchronization point.
    try:
        self.store.set(f"arrival_{barrier_id}_{self.rank}", b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier arrival") from e

    # Exceptions do not apply %-style arguments, so the message is
    # formatted here rather than passed as ("...%f...", timeout).
    wait_for_all("arrival", barrier_id,
                 f"Barrier timed out after {timeout:f} seconds")

    # Phase 2: Signal departure from barrier
    # We only care to block at this stage in rank 0, which runs the
    # server side of the TCPStore. We want to make sure that all
    # clients have departed the barrier before rank 0 in case the
    # next thing after the barrier is a shutdown, including tearing
    # down the TCPStore. Other ranks can exit the barrier immediately
    # after signaling their departure.
    try:
        self.store.set(f"departure_{barrier_id}_{self.rank}", b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier departure") from e

    if self.rank != 0:
        return

    # Make rank 0 wait for all processes to signal departure
    wait_for_all("departure", barrier_id,
                 f"Barrier departure timed out after {timeout:f} s")

    # Clean up keys to avoid leaking memory in the store. Deletion is
    # best-effort: a failure is logged at debug level and otherwise ignored.
    for i in range(self.world_size):
        for prefix in ("arrival", "departure"):
            key = f"{prefix}_{barrier_id}_{i}"
            try:
                self.store.delete_key(key)
            except Exception:
                logger.debug("Error deleting key: %s", key)
@
staticmethod
@
staticmethod
def
create
(
def
create
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment