Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
515080ad
Unverified
Commit
515080ad
authored
Jun 25, 2024
by
youkaichao
Committed by
GitHub
Jun 25, 2024
Browse files
[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)
parent
3aa7b6cf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
46 deletions
+76
-46
tests/distributed/test_shm_broadcast.py
tests/distributed/test_shm_broadcast.py
+33
-16
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+43
-30
No files found.
tests/distributed/test_shm_broadcast.py
View file @
515080ad
import
multiprocessing
import
random
import
time
from
typing
import
List
import
numpy
as
np
import
torch.distributed
as
dist
from
vllm.distributed.device_communicators.shm_broadcast
import
(
...
...
@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
from
vllm.utils
import
update_environment_variables
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
List
[
np
.
ndarray
]:
np
.
random
.
seed
(
seed
)
sizes
=
np
.
random
.
randint
(
1
,
10_000
,
n
)
# on average, each array will have 5k elements
# with int64, each array will have 40kb
return
[
np
.
random
.
randint
(
1
,
100
,
i
)
for
i
in
sizes
]
def
distributed_run
(
fn
,
world_size
):
number_of_processes
=
world_size
processes
=
[]
...
...
@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
def
worker_fn
():
writer_rank
=
2
broadcaster
=
ShmRingBufferIO
.
create_from_process_group
(
dist
.
group
.
WORLD
,
1024
,
2
,
writer_rank
)
dist
.
group
.
WORLD
,
1024
*
1024
,
2
,
writer_rank
)
if
dist
.
get_rank
()
==
writer_rank
:
seed
=
random
.
randint
(
0
,
1000
)
dist
.
broadcast_object_list
([
seed
],
writer_rank
)
else
:
recv
=
[
None
]
dist
.
broadcast_object_list
(
recv
,
writer_rank
)
seed
=
recv
[
0
]
# type: ignore
dist
.
barrier
()
# in case we find a race condition
# print the seed so that we can reproduce the error
print
(
f
"Rank
{
dist
.
get_rank
()
}
got seed
{
seed
}
"
)
# test broadcasting with about 400MB of data
N
=
10_000
if
dist
.
get_rank
()
==
writer_rank
:
time
.
sleep
(
random
.
random
())
broadcaster
.
broadcast_object
(
0
)
time
.
sleep
(
random
.
random
())
broadcaster
.
broadcast_object
({})
time
.
sleep
(
random
.
random
())
broadcaster
.
broadcast_object
([])
arrs
=
get_arrays
(
N
,
seed
)
for
x
in
arrs
:
broadcaster
.
broadcast_object
(
x
)
time
.
sleep
(
random
.
random
()
/
1000
)
else
:
time
.
sleep
(
random
.
random
())
a
=
broadcaster
.
broadcast_object
(
None
)
time
.
sleep
(
random
.
random
())
b
=
broadcaster
.
broadcast_object
(
None
)
time
.
sleep
(
random
.
random
())
c
=
broadcaster
.
broadcast_object
(
None
)
assert
a
==
0
assert
b
==
{}
assert
c
==
[]
arrs
=
get_arrays
(
N
,
seed
)
for
x
in
arrs
:
y
=
broadcaster
.
broadcast_object
(
None
)
assert
np
.
array_equal
(
x
,
y
)
time
.
sleep
(
random
.
random
()
/
1000
)
dist
.
barrier
()
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
515080ad
...
...
@@ -14,6 +14,12 @@ from vllm.logger import init_logger
VLLM_RINGBUFFER_WARNING_INTERVAL
=
envs
.
VLLM_RINGBUFFER_WARNING_INTERVAL
# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL
=
1e-7
logger
=
init_logger
(
__name__
)
...
...
@@ -145,8 +151,7 @@ class ShmRingBufferIO:
@
contextmanager
def
acquire_write
(
self
):
assert
self
.
_is_writer
,
"Only writers can acquire write"
start_index
=
self
.
current_idx
start_time
=
time
.
time
()
start_time
=
time
.
monotonic
()
n_warning
=
1
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
...
...
@@ -154,19 +159,21 @@ class ShmRingBufferIO:
written_flag
=
metadata_buffer
[
0
]
if
written_flag
and
read_count
!=
self
.
buffer
.
n_reader
:
# this block is written and not read by all readers
# try to write to the next block
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
if
self
.
current_idx
==
start_index
:
# no empty block found
if
time
.
time
(
# for writers, `self.current_idx` is the next block to write
# if this block is not ready to write,
# we need to wait until it is read by all readers
# wait for a while
time
.
sleep
(
RINGBUFFER_SLEEP_INTERVAL
)
# if we wait for a long time, we should warn the user
if
time
.
monotonic
(
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
logger
.
warning
(
"No available block found in %s second. "
,
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
# wait for a while (0.1 us)
time
.
sleep
(
1e-7
)
continue
# found a block that is either
# (1) not written
...
...
@@ -188,13 +195,14 @@ class ShmRingBufferIO:
metadata_buffer
[
i
]
=
0
# mark the block as written
metadata_buffer
[
0
]
=
1
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
break
@
contextmanager
def
acquire_read
(
self
):
assert
self
.
_is_reader
,
"Only readers can acquire read"
start_index
=
self
.
current_idx
start_time
=
time
.
time
()
start_time
=
time
.
monotonic
()
n_warning
=
1
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
...
...
@@ -204,19 +212,22 @@ class ShmRingBufferIO:
# this block is either
# (1) not written
# (2) already read by this reader
# try to read the next block
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
if
self
.
current_idx
==
start_index
:
# no block found
if
time
.
time
(
# for readers, `self.current_idx` is the next block to read
# if this block is not ready,
# we need to wait until it is written
# wait for a while
time
.
sleep
(
RINGBUFFER_SLEEP_INTERVAL
)
# if we wait for a long time, we should warn the user
if
time
.
monotonic
(
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
logger
.
warning
(
"No available block found in %s second. "
,
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
# wait for a while (0.1 us)
time
.
sleep
(
1e-7
)
continue
# found a block that is not read by this reader
# let caller read from the buffer
...
...
@@ -226,6 +237,8 @@ class ShmRingBufferIO:
# caller has read from the buffer
# set the read flag
metadata_buffer
[
self
.
reader_rank
+
1
]
=
1
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
break
def
enqueue
(
self
,
obj
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment