Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
515080ad
Unverified
Commit
515080ad
authored
Jun 25, 2024
by
youkaichao
Committed by
GitHub
Jun 25, 2024
Browse files
[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)
parent
3aa7b6cf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
46 deletions
+76
-46
tests/distributed/test_shm_broadcast.py
tests/distributed/test_shm_broadcast.py
+33
-16
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+43
-30
No files found.
tests/distributed/test_shm_broadcast.py
View file @
515080ad
import
multiprocessing
import
multiprocessing
import
random
import
random
import
time
import
time
from
typing
import
List
import
numpy
as
np
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
vllm.distributed.device_communicators.shm_broadcast
import
(
from
vllm.distributed.device_communicators.shm_broadcast
import
(
...
@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
...
@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
from
vllm.utils
import
update_environment_variables
from
vllm.utils
import
update_environment_variables
def
get_arrays
(
n
:
int
,
seed
:
int
=
0
)
->
List
[
np
.
ndarray
]:
np
.
random
.
seed
(
seed
)
sizes
=
np
.
random
.
randint
(
1
,
10_000
,
n
)
# on average, each array will have 5k elements
# with int64, each array will have 40kb
return
[
np
.
random
.
randint
(
1
,
100
,
i
)
for
i
in
sizes
]
def
distributed_run
(
fn
,
world_size
):
def
distributed_run
(
fn
,
world_size
):
number_of_processes
=
world_size
number_of_processes
=
world_size
processes
=
[]
processes
=
[]
...
@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
...
@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
def
worker_fn
():
def
worker_fn
():
writer_rank
=
2
writer_rank
=
2
broadcaster
=
ShmRingBufferIO
.
create_from_process_group
(
broadcaster
=
ShmRingBufferIO
.
create_from_process_group
(
dist
.
group
.
WORLD
,
1024
,
2
,
writer_rank
)
dist
.
group
.
WORLD
,
1024
*
1024
,
2
,
writer_rank
)
if
dist
.
get_rank
()
==
writer_rank
:
seed
=
random
.
randint
(
0
,
1000
)
dist
.
broadcast_object_list
([
seed
],
writer_rank
)
else
:
recv
=
[
None
]
dist
.
broadcast_object_list
(
recv
,
writer_rank
)
seed
=
recv
[
0
]
# type: ignore
dist
.
barrier
()
# in case we find a race condition
# print the seed so that we can reproduce the error
print
(
f
"Rank
{
dist
.
get_rank
()
}
got seed
{
seed
}
"
)
# test broadcasting with about 400MB of data
N
=
10_000
if
dist
.
get_rank
()
==
writer_rank
:
if
dist
.
get_rank
()
==
writer_rank
:
time
.
sleep
(
random
.
random
())
arrs
=
get_arrays
(
N
,
seed
)
broadcaster
.
broadcast_object
(
0
)
for
x
in
arrs
:
time
.
sleep
(
random
.
random
())
broadcaster
.
broadcast_object
(
x
)
broadcaster
.
broadcast_object
({})
time
.
sleep
(
random
.
random
()
/
1000
)
time
.
sleep
(
random
.
random
())
broadcaster
.
broadcast_object
([])
else
:
else
:
time
.
sleep
(
random
.
random
())
arrs
=
get_arrays
(
N
,
seed
)
a
=
broadcaster
.
broadcast_object
(
None
)
for
x
in
arrs
:
time
.
sleep
(
random
.
random
())
y
=
broadcaster
.
broadcast_object
(
None
)
b
=
broadcaster
.
broadcast_object
(
None
)
assert
np
.
array_equal
(
x
,
y
)
time
.
sleep
(
random
.
random
())
time
.
sleep
(
random
.
random
()
/
1000
)
c
=
broadcaster
.
broadcast_object
(
None
)
assert
a
==
0
assert
b
==
{}
assert
c
==
[]
dist
.
barrier
()
dist
.
barrier
()
...
...
vllm/distributed/device_communicators/shm_broadcast.py
View file @
515080ad
...
@@ -14,6 +14,12 @@ from vllm.logger import init_logger
...
@@ -14,6 +14,12 @@ from vllm.logger import init_logger
VLLM_RINGBUFFER_WARNING_INTERVAL
=
envs
.
VLLM_RINGBUFFER_WARNING_INTERVAL
VLLM_RINGBUFFER_WARNING_INTERVAL
=
envs
.
VLLM_RINGBUFFER_WARNING_INTERVAL
# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL
=
1e-7
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -145,8 +151,7 @@ class ShmRingBufferIO:
...
@@ -145,8 +151,7 @@ class ShmRingBufferIO:
@
contextmanager
@
contextmanager
def
acquire_write
(
self
):
def
acquire_write
(
self
):
assert
self
.
_is_writer
,
"Only writers can acquire write"
assert
self
.
_is_writer
,
"Only writers can acquire write"
start_index
=
self
.
current_idx
start_time
=
time
.
monotonic
()
start_time
=
time
.
time
()
n_warning
=
1
n_warning
=
1
while
True
:
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
...
@@ -154,19 +159,21 @@ class ShmRingBufferIO:
...
@@ -154,19 +159,21 @@ class ShmRingBufferIO:
written_flag
=
metadata_buffer
[
0
]
written_flag
=
metadata_buffer
[
0
]
if
written_flag
and
read_count
!=
self
.
buffer
.
n_reader
:
if
written_flag
and
read_count
!=
self
.
buffer
.
n_reader
:
# this block is written and not read by all readers
# this block is written and not read by all readers
# try to write to the next block
# for writers, `self.current_idx` is the next block to write
self
.
current_idx
=
(
self
.
current_idx
+
# if this block is not ready to write,
1
)
%
self
.
buffer
.
max_chunks
# we need to wait until it is read by all readers
if
self
.
current_idx
==
start_index
:
# no empty block found
# wait for a while
if
time
.
time
(
time
.
sleep
(
RINGBUFFER_SLEEP_INTERVAL
)
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
logger
.
warning
(
# if we wait for a long time, we should warn the user
"No available block found in %s second. "
,
if
time
.
monotonic
(
VLLM_RINGBUFFER_WARNING_INTERVAL
)
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
n_warning
+=
1
logger
.
warning
(
# wait for a while (0.1 us)
"No available block found in %s second. "
,
time
.
sleep
(
1e-7
)
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
continue
continue
# found a block that is either
# found a block that is either
# (1) not written
# (1) not written
...
@@ -188,13 +195,14 @@ class ShmRingBufferIO:
...
@@ -188,13 +195,14 @@ class ShmRingBufferIO:
metadata_buffer
[
i
]
=
0
metadata_buffer
[
i
]
=
0
# mark the block as written
# mark the block as written
metadata_buffer
[
0
]
=
1
metadata_buffer
[
0
]
=
1
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
break
break
@
contextmanager
@
contextmanager
def
acquire_read
(
self
):
def
acquire_read
(
self
):
assert
self
.
_is_reader
,
"Only readers can acquire read"
assert
self
.
_is_reader
,
"Only readers can acquire read"
start_index
=
self
.
current_idx
start_time
=
time
.
monotonic
()
start_time
=
time
.
time
()
n_warning
=
1
n_warning
=
1
while
True
:
while
True
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
with
self
.
buffer
.
get_metadata
(
self
.
current_idx
)
as
metadata_buffer
:
...
@@ -204,19 +212,22 @@ class ShmRingBufferIO:
...
@@ -204,19 +212,22 @@ class ShmRingBufferIO:
# this block is either
# this block is either
# (1) not written
# (1) not written
# (2) already read by this reader
# (2) already read by this reader
# try to read the next block
self
.
current_idx
=
(
self
.
current_idx
+
# for readers, `self.current_idx` is the next block to read
1
)
%
self
.
buffer
.
max_chunks
# if this block is not ready,
if
self
.
current_idx
==
start_index
:
# we need to wait until it is written
# no block found
if
time
.
time
(
# wait for a while
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
time
.
sleep
(
RINGBUFFER_SLEEP_INTERVAL
)
logger
.
warning
(
"No available block found in %s second. "
,
# if we wait for a long time, we should warn the user
VLLM_RINGBUFFER_WARNING_INTERVAL
)
if
time
.
monotonic
(
n_warning
+=
1
)
-
start_time
>
VLLM_RINGBUFFER_WARNING_INTERVAL
*
n_warning
:
# noqa
# wait for a while (0.1 us)
logger
.
warning
(
time
.
sleep
(
1e-7
)
"No available block found in %s second. "
,
VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning
+=
1
continue
continue
# found a block that is not read by this reader
# found a block that is not read by this reader
# let caller read from the buffer
# let caller read from the buffer
...
@@ -226,6 +237,8 @@ class ShmRingBufferIO:
...
@@ -226,6 +237,8 @@ class ShmRingBufferIO:
# caller has read from the buffer
# caller has read from the buffer
# set the read flag
# set the read flag
metadata_buffer
[
self
.
reader_rank
+
1
]
=
1
metadata_buffer
[
self
.
reader_rank
+
1
]
=
1
self
.
current_idx
=
(
self
.
current_idx
+
1
)
%
self
.
buffer
.
max_chunks
break
break
def
enqueue
(
self
,
obj
):
def
enqueue
(
self
,
obj
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment