Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9a252bc
Unverified
Commit
d9a252bc
authored
Jun 20, 2024
by
youkaichao
Committed by
GitHub
Jun 21, 2024
Browse files
[Core][Distributed] add shm broadcast (#5399)
Co-authored-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
67005a07
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
384 additions
and
10 deletions
+384
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-1
tests/distributed/test_shm_broadcast.py
tests/distributed/test_shm_broadcast.py
+82
-0
vllm/distributed/device_communicators/shm_broadcast.py
vllm/distributed/device_communicators/shm_broadcast.py
+259
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+35
-9
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
d9a252bc
...
@@ -28,9 +28,11 @@ steps:
...
@@ -28,9 +28,11 @@ steps:
-
label
:
Distributed Comm Ops Test
-
label
:
Distributed Comm Ops Test
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
command
:
pytest -v -s distributed/test_comm_ops.py
working_dir
:
"
/vllm-workspace/tests"
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
num_gpus
:
2
commands
:
-
pytest -v -s distributed/test_comm_ops.py
-
pytest -v -s distributed/test_shm_broadcast.py
-
label
:
Distributed Tests (2 GPUs)
-
label
:
Distributed Tests (2 GPUs)
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
...
...
tests/distributed/test_shm_broadcast.py
0 → 100644
View file @
d9a252bc
import
multiprocessing
import
random
import
time
import
torch.distributed
as
dist
from
vllm.distributed.device_communicators.shm_broadcast
import
(
ShmRingBuffer
,
ShmRingBufferIO
)
from
vllm.utils
import
update_environment_variables
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require clean exits.

    Each child is started with ``fn`` as its target and receives a single
    positional argument: a dict of rendezvous settings (``RANK``,
    ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``,
    ``MASTER_PORT``), all values being strings.

    Raises:
        AssertionError: if any child finishes with a non-zero exit code.
    """
    size_str = str(world_size)
    workers = []
    for rank in range(world_size):
        rank_str = str(rank)
        # Every process lives on this machine, so the local rank/world size
        # are the same as the global ones.
        env = {
            'RANK': rank_str,
            'LOCAL_RANK': rank_str,
            'WORLD_SIZE': size_str,
            'LOCAL_WORLD_SIZE': size_str,
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        worker = multiprocessing.Process(target=fn, args=(env, ))
        workers.append(worker)
        worker.start()

    # Wait for everybody first, then check the results, so that one failing
    # child does not leave the others un-joined.
    for worker in workers:
        worker.join()

    for worker in workers:
        assert worker.exitcode == 0
def worker_fn_wrapper(fn):
    """Decorate a zero-argument worker so it can be a process target.

    `multiprocessing.Process` has no way to hand environment variables to the
    child directly. The returned callable therefore takes them as an
    argument, installs them, initializes the default process group with the
    "gloo" backend, and only then runs ``fn``.
    """

    def wrapped_fn(env_vars):
        # Install the variables before init_process_group is called.
        update_environment_variables(env_vars)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
@worker_fn_wrapper
def worker_fn():
    """Per-process body of the shared-memory broadcast test.

    Rank 2 acts as the writer and sends three objects; every other rank
    receives three objects and checks that they arrive as sent and in the
    same order. Random sleeps before each call de-synchronize the processes.
    """
    writer_rank = 2
    broadcaster = ShmRingBufferIO.create_from_process_group(
        dist.group.WORLD, 1024, 2, writer_rank)

    payloads = (0, {}, [])
    if dist.get_rank() == writer_rank:
        for payload in payloads:
            time.sleep(random.random())
            broadcaster.broadcast_object(payload)
    else:
        received = []
        for _ in payloads:
            time.sleep(random.random())
            received.append(broadcaster.broadcast_object(None))
        a, b, c = received
        assert a == 0
        assert b == {}
        assert c == []

    # Keep every process alive until all of them are done with the buffer.
    dist.barrier()
def test_shm_broadcast():
    """Run the broadcast worker in four processes (one writer, three readers)."""
    distributed_run(fn=worker_fn, world_size=4)
def test_singe_process():
    """Writer and reader sharing one buffer inside a single process.

    Two items are enqueued back to back and must be dequeued in the order
    they were written.
    """
    ring = ShmRingBuffer(1, 1024, 4)
    reader = ShmRingBufferIO(ring, reader_rank=0)
    # A reader_rank of -1 designates the writer side.
    writer = ShmRingBufferIO(ring, reader_rank=-1)

    for item in ([0], [1]):
        writer.enqueue(item)

    for expected in ([0], [1]):
        assert reader.dequeue() == expected
vllm/distributed/device_communicators/shm_broadcast.py
0 → 100644
View file @
d9a252bc
import
pickle
import
time
from
contextlib
import
contextmanager
from
multiprocessing
import
shared_memory
from
typing
import
Optional
from
unittest.mock
import
patch
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
# Snapshot of the `VLLM_RINGBUFFER_WARNING_INTERVAL` setting taken when this
# module is imported. The acquire loops below compare the elapsed waiting
# time against multiples of this value to decide when to log a warning.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

# Module-level logger used for the "no available block" warnings.
logger = init_logger(__name__)
class ShmRingBuffer:
    """Fixed-size shared-memory region holding data chunks and their flags.

    This class only owns the memory and hands out views into it; the
    read/write protocol that interprets the flags lives in the users of the
    buffer.
    """

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.

        Args:
            n_reader: number of readers; one flag byte per reader is kept
                for every chunk.
            max_chunk_bytes: size of each data chunk in bytes.
            max_chunks: number of chunks in the ring.
            name: name of an existing shared memory segment to attach to,
                or None to create a new segment.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[
                    self.metadata_offset:self.total_bytes_of_buffer]
                            ) as metadata_buffer:
                metadata_buffer[:] = bytes(len(metadata_buffer))
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            # Some platforms allocate shared memory in page-size units, so
            # the attached block may be larger than what was requested.
            assert self.shared_memory.size >= self.total_bytes_of_buffer
            # Only look at the logical metadata region, not at any padding
            # the platform may have appended.
            with memoryview(self.shared_memory.buf[
                    self.metadata_offset:self.total_bytes_of_buffer]
                            ) as metadata_buffer:
                assert not any(metadata_buffer)

    def __reduce__(self):
        """Pickle as constructor arguments plus the segment name.

        Unpickling therefore calls `__init__` with `name` set, which attaches
        to the existing segment instead of creating a new one.
        """
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        """Close this handle; the creator additionally unlinks the segment."""
        # `__init__` may have raised before `shared_memory` was assigned
        # (e.g. the segment could not be created or opened); there is
        # nothing to release in that case.
        if not hasattr(self, "shared_memory"):
            return
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over the data chunk at index `current_idx`.

        The view is `max_chunk_bytes` long and is released when the context
        exits.
        """
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the flag bytes of chunk `current_idx`.

        The view is `metadata_size` (1 + n_reader) bytes long and is released
        when the context exits.
        """
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
class ShmRingBufferIO:
    """Writer or reader endpoint on top of a `ShmRingBuffer`.

    Chunks are used strictly in ring order: the writer fills chunk 0, 1, ...
    and every reader consumes chunk 0, 1, ... . An endpoint that finds its
    current chunk unavailable waits for it instead of skipping ahead, which
    is what guarantees that objects are dequeued in the order they were
    enqueued.
    """

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        """
        Args:
            buffer: the shared ring buffer to operate on.
            reader_rank: -1 for the writer, otherwise the index of this
                reader in ``[0, buffer.n_reader)``.
        """
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                f" created with {buffer.n_reader} readers")
        # Index of the next chunk this endpoint will write / read.
        self.current_idx = 0

    def _wait_for_block(self, start_time: float, n_warning: int) -> int:
        """Back off briefly while the current chunk is unavailable.

        Logs a warning each time another VLLM_RINGBUFFER_WARNING_INTERVAL
        seconds have elapsed since `start_time`, and returns the updated
        warning counter.
        """
        elapsed = time.monotonic() - start_time
        if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
            logger.warning("No available block found in %s second. ",
                           VLLM_RINGBUFFER_WARNING_INTERVAL)
            n_warning += 1
        # wait for a while (0.1 us)
        time.sleep(1e-7)
        return n_warning

    @contextmanager
    def acquire_write(self):
        """Yield the next chunk for writing once every reader is done with it.

        On normal exit from the context the chunk is published to the
        readers and the writer advances to the following chunk.
        """
        assert self._is_writer, "Only writers can acquire write"
        # monotonic clock: elapsed time must not be affected by clock jumps
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers.
                    # `current_idx` is the next block in ring order, so we
                    # must wait for it; writing elsewhere would let readers
                    # observe messages out of order.
                    n_warning = self._wait_for_block(start_time, n_warning)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here. First reset the reader
                # flags, then set the written flag; otherwise a reader could
                # see the block as written while its own stale flag still
                # says "already read".
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self):
        """Yield the next chunk for reading once the writer has published it.

        On normal exit from the context this reader's flag is set and the
        reader advances to the following chunk.
        """
        assert self._is_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # `current_idx` is the next block in ring order, so wait
                    # for the writer to (re)fill it rather than skipping.
                    n_warning = self._wait_for_block(start_time, n_warning)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj):
        """Pickle `obj` and write it into the next chunk (writer only).

        Raises:
            RuntimeError: if the pickled object does not fit into one chunk.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} larger than the allowed value "
                f"{self.buffer.max_chunk_bytes},"
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        """Read and unpickle the next object (reader only)."""
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of serialized object
            # pickle format itself contains the size information internally
            # see https://docs.python.org/3/library/pickle.html
            # NOTE(security): unpickling executes arbitrary code; anything
            # able to write to the shared memory segment is trusted here.
            obj = pickle.loads(buf)
        return obj

    def broadcast_object(self, obj=None):
        """Send `obj` if this is the writer, otherwise receive an object.

        Returns:
            `obj` itself on the writer, the received object on a reader.
        """
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "ShmRingBufferIO":
        """Collectively create one endpoint per member of `pg`.

        The rank `writer_rank` (a rank inside `pg`) creates the buffer and
        broadcasts it to the other members, which attach to it as readers.
        Must be called by every process of the group.

        Args:
            pg: process group whose members share the buffer.
            max_chunk_bytes: size of each chunk in bytes.
            max_chunks: number of chunks in the ring.
            writer_rank: rank inside `pg` of the writer process.
        """
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            # `src` is a global rank; `group=pg` keeps the collective inside
            # the given group instead of the default (world) group.
            dist.broadcast_object_list([buffer],
                                       src=global_ranks[writer_rank],
                                       group=pg)
            # wait until every reader has attached to the buffer
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv,
                                       src=global_ranks[writer_rank],
                                       group=pg)
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            # Reader indices are the group ranks with the writer removed.
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
vllm/distributed/parallel_state.py
View file @
d9a252bc
...
@@ -98,6 +98,7 @@ class GroupCoordinator:
...
@@ -98,6 +98,7 @@ class GroupCoordinator:
# communicators are only created for world size > 1
# communicators are only created for world size > 1
pynccl_comm
:
Optional
[
Any
]
# PyNccl communicator
pynccl_comm
:
Optional
[
Any
]
# PyNccl communicator
ca_comm
:
Optional
[
Any
]
# Custom allreduce communicator
ca_comm
:
Optional
[
Any
]
# Custom allreduce communicator
shm_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -162,6 +163,13 @@ class GroupCoordinator:
...
@@ -162,6 +163,13 @@ class GroupCoordinator:
else
:
else
:
self
.
ca_comm
=
None
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.shm_broadcast
import
(
ShmRingBufferIO
)
self
.
shm_broadcaster
:
Optional
[
ShmRingBufferIO
]
=
None
if
self
.
world_size
>
1
and
is_in_the_same_node
(
self
.
cpu_group
):
self
.
shm_broadcaster
=
ShmRingBufferIO
.
create_from_process_group
(
self
.
cpu_group
,
1
<<
20
,
6
)
@
property
@
property
def
first_rank
(
self
):
def
first_rank
(
self
):
"""Return the global rank of the first process in the group"""
"""Return the global rank of the first process in the group"""
...
@@ -324,6 +332,30 @@ class GroupCoordinator:
...
@@ -324,6 +332,30 @@ class GroupCoordinator:
group
=
self
.
device_group
)
group
=
self
.
device_group
)
return
input_
return
input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
    """Broadcast a single Python object to every rank of this group.

    NOTE: `src` is the local rank of the source rank.

    Args:
        obj: the object to send; only meaningful on the source rank.
        src: rank inside the group that provides the object.

    Returns:
        `obj` on the source rank (and whenever the group has a single
        member), the received object on every other rank.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj

    # Prefer the shared memory broadcaster when one was set up.
    shm = self.shm_broadcaster
    if shm is not None:
        assert src == 0, "Shared memory broadcaster only supports src=0"
        return shm.broadcast_object(obj)

    # Fall back to torch.distributed on the CPU group; `self.ranks[src]`
    # translates the group-local rank into the rank expected by torch.
    is_source = self.rank_in_group == src
    payload = [obj] if is_source else [None]
    torch.distributed.broadcast_object_list(payload,
                                            src=self.ranks[src],
                                            group=self.cpu_group)
    return obj if is_source else payload[0]
def
broadcast_object_list
(
self
,
def
broadcast_object_list
(
self
,
obj_list
:
List
[
Any
],
obj_list
:
List
[
Any
],
src
:
int
=
0
,
src
:
int
=
0
,
...
@@ -371,9 +403,7 @@ class GroupCoordinator:
...
@@ -371,9 +403,7 @@ class GroupCoordinator:
# `metadata_list` lives in CPU memory.
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
# all happening on CPU. Therefore, we can use the CPU group.
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
self
.
broadcast_object
(
metadata_list
,
src
=
src
)
src
=
src
,
group
=
metadata_group
)
async_handles
=
[]
async_handles
=
[]
for
tensor
in
tensor_list
:
for
tensor
in
tensor_list
:
if
tensor
.
numel
()
==
0
:
if
tensor
.
numel
()
==
0
:
...
@@ -396,14 +426,10 @@ class GroupCoordinator:
...
@@ -396,14 +426,10 @@ class GroupCoordinator:
async_handle
.
wait
()
async_handle
.
wait
()
else
:
else
:
recv_metadata_list
=
[
None
]
metadata_list
=
self
.
broadcast_object
(
None
,
src
=
src
)
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
src
=
src
,
group
=
metadata_group
)
assert
recv_metadata_list
[
0
]
is
not
None
tensor_dict
=
{}
tensor_dict
=
{}
async_handles
=
[]
async_handles
=
[]
for
key
,
value
in
recv_
metadata_list
[
0
]
:
for
key
,
value
in
metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
torch
.
empty
(
value
.
size
,
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
dtype
=
value
.
dtype
,
...
...
vllm/envs.py
View file @
d9a252bc
...
@@ -5,6 +5,7 @@ if TYPE_CHECKING:
...
@@ -5,6 +5,7 @@ if TYPE_CHECKING:
VLLM_HOST_IP
:
str
=
""
VLLM_HOST_IP
:
str
=
""
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_PORT
:
Optional
[
int
]
=
None
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_USE_MODELSCOPE
:
bool
=
False
VLLM_RINGBUFFER_WARNING_INTERVAL
:
int
=
60
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_INSTANCE_ID
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
...
@@ -114,6 +115,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -114,6 +115,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_INSTANCE_ID"
:
"VLLM_INSTANCE_ID"
:
lambda
:
os
.
environ
.
get
(
"VLLM_INSTANCE_ID"
,
None
),
lambda
:
os
.
environ
.
get
(
"VLLM_INSTANCE_ID"
,
None
),
# Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_RINGBUFFER_WARNING_INTERVAL"
,
"60"
)),
# path to cudatoolkit home directory, under which should be bin, include,
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
# and lib directories.
"CUDA_HOME"
:
"CUDA_HOME"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment