Unverified Commit 515080ad authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)

parent 3aa7b6cf
import multiprocessing import multiprocessing
import random import random
import time import time
from typing import List
import numpy as np
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import ( from vllm.distributed.device_communicators.shm_broadcast import (
...@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import ( ...@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
np.random.seed(seed)
sizes = np.random.randint(1, 10_000, n)
# on average, each array will have 5k elements
# with int64, each array will have 40kb
return [np.random.randint(1, 100, i) for i in sizes]
def distributed_run(fn, world_size): def distributed_run(fn, world_size):
number_of_processes = world_size number_of_processes = world_size
processes = [] processes = []
...@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn): ...@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
def worker_fn(): def worker_fn():
writer_rank = 2 writer_rank = 2
broadcaster = ShmRingBufferIO.create_from_process_group( broadcaster = ShmRingBufferIO.create_from_process_group(
dist.group.WORLD, 1024, 2, writer_rank) dist.group.WORLD, 1024 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank: if dist.get_rank() == writer_rank:
time.sleep(random.random()) arrs = get_arrays(N, seed)
broadcaster.broadcast_object(0) for x in arrs:
time.sleep(random.random()) broadcaster.broadcast_object(x)
broadcaster.broadcast_object({}) time.sleep(random.random() / 1000)
time.sleep(random.random())
broadcaster.broadcast_object([])
else: else:
time.sleep(random.random()) arrs = get_arrays(N, seed)
a = broadcaster.broadcast_object(None) for x in arrs:
time.sleep(random.random()) y = broadcaster.broadcast_object(None)
b = broadcaster.broadcast_object(None) assert np.array_equal(x, y)
time.sleep(random.random()) time.sleep(random.random() / 1000)
c = broadcaster.broadcast_object(None)
assert a == 0
assert b == {}
assert c == []
dist.barrier() dist.barrier()
......
...@@ -14,6 +14,12 @@ from vllm.logger import init_logger ...@@ -14,6 +14,12 @@ from vllm.logger import init_logger
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL = 1e-7
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -145,8 +151,7 @@ class ShmRingBufferIO: ...@@ -145,8 +151,7 @@ class ShmRingBufferIO:
@contextmanager @contextmanager
def acquire_write(self): def acquire_write(self):
assert self._is_writer, "Only writers can acquire write" assert self._is_writer, "Only writers can acquire write"
start_index = self.current_idx start_time = time.monotonic()
start_time = time.time()
n_warning = 1 n_warning = 1
while True: while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer: with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
...@@ -154,19 +159,21 @@ class ShmRingBufferIO: ...@@ -154,19 +159,21 @@ class ShmRingBufferIO:
written_flag = metadata_buffer[0] written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader: if written_flag and read_count != self.buffer.n_reader:
# this block is written and not read by all readers # this block is written and not read by all readers
# try to write to the next block # for writers, `self.current_idx` is the next block to write
self.current_idx = (self.current_idx + # if this block is not ready to write,
1) % self.buffer.max_chunks # we need to wait until it is read by all readers
if self.current_idx == start_index:
# no empty block found # wait for a while
if time.time( time.sleep(RINGBUFFER_SLEEP_INTERVAL)
# if we wait for a long time, we should warn the user
if time.monotonic(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning( logger.warning(
"No available block found in %s second. ", "No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1 n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)
continue continue
# found a block that is either # found a block that is either
# (1) not written # (1) not written
...@@ -188,13 +195,14 @@ class ShmRingBufferIO: ...@@ -188,13 +195,14 @@ class ShmRingBufferIO:
metadata_buffer[i] = 0 metadata_buffer[i] = 0
# mark the block as written # mark the block as written
metadata_buffer[0] = 1 metadata_buffer[0] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break break
@contextmanager @contextmanager
def acquire_read(self): def acquire_read(self):
assert self._is_reader, "Only readers can acquire read" assert self._is_reader, "Only readers can acquire read"
start_index = self.current_idx start_time = time.monotonic()
start_time = time.time()
n_warning = 1 n_warning = 1
while True: while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer: with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
...@@ -204,19 +212,22 @@ class ShmRingBufferIO: ...@@ -204,19 +212,22 @@ class ShmRingBufferIO:
# this block is either # this block is either
# (1) not written # (1) not written
# (2) already read by this reader # (2) already read by this reader
# try to read the next block
self.current_idx = (self.current_idx + # for readers, `self.current_idx` is the next block to read
1) % self.buffer.max_chunks # if this block is not ready,
if self.current_idx == start_index: # we need to wait until it is written
# no block found
if time.time( # wait for a while
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
# if we wait for a long time, we should warn the user
if time.monotonic(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning( logger.warning(
"No available block found in %s second. ", "No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1 n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)
continue continue
# found a block that is not read by this reader # found a block that is not read by this reader
# let caller read from the buffer # let caller read from the buffer
...@@ -226,6 +237,8 @@ class ShmRingBufferIO: ...@@ -226,6 +237,8 @@ class ShmRingBufferIO:
# caller has read from the buffer # caller has read from the buffer
# set the read flag # set the read flag
metadata_buffer[self.reader_rank + 1] = 1 metadata_buffer[self.reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
break break
def enqueue(self, obj): def enqueue(self, obj):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment