Commit 6f0dd938 (unverified), authored by Joe Runde and committed by GitHub
Browse files

[Core] Remove busy loop from idle buffer readers (#28053)


Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 5d199ac8
...@@ -124,8 +124,6 @@ def test_models( ...@@ -124,8 +124,6 @@ def test_models(
[ [
("facebook/opt-125m", "ray", "", "L4", {}), ("facebook/opt-125m", "ray", "", "L4", {}),
("facebook/opt-125m", "mp", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "A100", {}), ("facebook/opt-125m", "ray", "", "A100", {}),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random import random
import threading
import time import time
from unittest import mock
import multiprocess as mp
import numpy as np import numpy as np
import pytest
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
...@@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: ...@@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
return [np.random.randint(1, 100, i) for i in sizes] return [np.random.randint(1, 100, i) for i in sizes]
def distributed_run(fn, world_size): def distributed_run(fn, world_size, timeout=60):
"""Run a function in multiple processes with proper error handling.
Args:
fn: Function to run in each process
world_size: Number of processes to spawn
timeout: Maximum time in seconds to wait for processes (default: 60)
"""
number_of_processes = world_size number_of_processes = world_size
processes = [] processes = []
for i in range(number_of_processes): for i in range(number_of_processes):
...@@ -33,19 +43,45 @@ def distributed_run(fn, world_size): ...@@ -33,19 +43,45 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost" env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345" env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,)) p = mp.Process(target=fn, args=(env,))
processes.append(p) processes.append(p)
p.start() p.start()
for p in processes: # Monitor processes and fail fast if any process fails
start_time = time.time()
failed_processes = []
# Wait for all processes, checking for failures
while time.time() - start_time < timeout:
all_done = True
for i, p in enumerate(processes):
if p.is_alive():
all_done = False
elif p.exitcode != 0:
# Process failed
failed_processes.append((i, p.exitcode))
break
if failed_processes or all_done:
break
time.sleep(0.1) # Check every 100ms
# Check for timeout if no failures detected yet
for i, p in enumerate(processes):
if p.is_alive():
p.kill()
p.join() p.join()
for p in processes: # Report failures
assert p.exitcode == 0 if failed_processes:
error_msg = "Distributed test failed:\n"
for rank, status in failed_processes:
error_msg += f" Rank {rank}: Exit code {status}\n"
raise AssertionError(error_msg)
def worker_fn_wrapper(fn): def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly # `mp.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments # so we need to pass the environment variables as arguments
# and update the environment variables in the function # and update the environment variables in the function
def wrapped_fn(env): def wrapped_fn(env):
...@@ -115,3 +151,244 @@ def worker_fn(): ...@@ -115,3 +151,244 @@ def worker_fn():
def test_shm_broadcast():
    """Run `worker_fn` in four processes via `distributed_run`."""
    distributed_run(worker_fn, world_size=4)
@worker_fn_wrapper
def worker_fn_test_shutdown_busy():
    """Check that `shutdown()` interrupts a reader that is busy-spinning.

    Every non-writer rank forces its spin condition to stay in busy mode,
    confirms that an empty queue times out, then has a helper thread call
    `shutdown()` and expects the next `dequeue` to raise
    `RuntimeError("cancelled")`.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into busy mode
        message_queue._spin_condition.busy_loop_s = 9999

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        # The helper is a daemon thread: if an assertion below fails before
        # `shutdown_event.set()` is reached, a non-daemon thread would stay
        # blocked in `wait()` forever and keep this worker process from
        # exiting with its failure status.
        threading.Thread(
            target=shutdown_thread,
            args=(message_queue, shutdown_event),
            daemon=True,
        ).start()

        # Nothing has been written, so the read must time out
        with pytest.raises(TimeoutError):
            message_queue.dequeue(timeout=0.01)

        # Let the helper thread call shutdown()
        shutdown_event.set()

        with pytest.raises(RuntimeError, match="cancelled"):
            message_queue.dequeue(timeout=1)

        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()
def test_message_queue_shutdown_busy(caplog_vllm):
    """Run the busy-mode shutdown worker in four processes, then print the
    text captured by the `caplog_vllm` fixture."""
    distributed_run(worker_fn_test_shutdown_busy, world_size=4)
    captured_logs = caplog_vllm.text
    print(captured_logs)
@worker_fn_wrapper
def worker_fn_test_shutdown_idle():
    """Check that `shutdown()` wakes up a reader that is waiting in idle mode.

    Every non-writer rank resets `last_read` so its spin condition treats
    the reader as idle, confirms that an empty queue times out, then has a
    helper thread call `shutdown()` and expects the next `dequeue` to raise
    `RuntimeError("cancelled")`.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into idle mode
        message_queue._spin_condition.last_read = 0

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        # The helper is a daemon thread: if an assertion below fails before
        # `shutdown_event.set()` is reached, a non-daemon thread would stay
        # blocked in `wait()` forever and keep this worker process from
        # exiting with its failure status.
        threading.Thread(
            target=shutdown_thread,
            args=(message_queue, shutdown_event),
            daemon=True,
        ).start()

        # Nothing has been written, so the read must time out
        with pytest.raises(TimeoutError):
            message_queue.dequeue(timeout=0.01)

        # Let the helper thread call shutdown()
        shutdown_event.set()

        with pytest.raises(RuntimeError, match="cancelled"):
            message_queue.dequeue(timeout=1)

        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()
def test_message_queue_shutdown_idle():
    """Run the idle-mode shutdown worker in four processes."""
    distributed_run(worker_fn_test_shutdown_idle, world_size=4)
@worker_fn_wrapper
def worker_fn_test_idle_to_busy():
    """Check that an idle reader wakes on a write and then busy-spins.

    The reader's `wait` is wrapped so its call count can be inspected: one
    call per dequeue while idle, and several calls once a successful read
    has put the reader back into busy mode. The writer sends two messages,
    each after a barrier shared with the readers.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    message1 = "hello world"
    # Every rank builds its own copy of the expected payload, so it is drawn
    # from a fixed-seed generator instead of the unseeded global NumPy RNG.
    # That way the reader/writer comparison does not depend on the ranks
    # happening to share RNG state (e.g. inherited from a forked parent).
    message2 = np.random.RandomState(0).randint(1, 100, 100)
    with mock.patch.object(
        message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
    ) as wrapped_wait:
        if not message_queue._is_writer:
            # Put into idle mode
            message_queue._spin_condition.last_read = 0

            # no messages, so expect a TimeoutError
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
            # wait should only be called once while idle
            assert wrapped_wait.call_count == 1

            # sync with the writer and wait for message1
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=5)
            assert recv_message == message1
            # second call to wait, with a message read, this puts in a busy spin
            assert wrapped_wait.call_count == 2

            # sync with the writer and wait for message2
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=1)
            assert np.array_equal(recv_message, message2)
            # in busy mode, we expect wait to have been called multiple times
            assert wrapped_wait.call_count > 3
        else:
            # writer writes two messages in sync with the reader
            dist.barrier()
            # sleep delays the send to ensure reader enters the read loop
            time.sleep(0.1)
            message_queue.enqueue(message1)

            dist.barrier()
            time.sleep(0.1)
            message_queue.enqueue(message2)

    message_queue.shutdown()
    assert message_queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")
def test_message_queue_idle_wake():
    """Run the idle-to-busy transition worker in four processes."""
    distributed_run(worker_fn_test_idle_to_busy, world_size=4)
@worker_fn_wrapper
def worker_fn_test_busy_to_idle():
    """Check that a busy-spinning reader drops back to idle mode.

    Readers start in busy mode (many `wait` calls per read), then have the
    busy window set to zero; after one timed-out read, the next successful
    read must cost exactly one additional `wait` call. The writer sends two
    messages, each after a barrier shared with the readers.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    message1 = 12345
    message2 = list(range(3))
    spin_condition = message_queue._spin_condition
    with mock.patch.object(
        spin_condition, "wait", wraps=spin_condition.wait
    ) as wrapped_wait:
        if message_queue._is_writer:
            # writer writes two messages in sync with the reader
            dist.barrier()
            # sleep delays the send to ensure reader enters the read loop
            time.sleep(0.1)
            message_queue.enqueue(message1)

            dist.barrier()
            time.sleep(0.1)
            message_queue.enqueue(message2)
        else:
            # Put into busy mode
            message_queue._spin_condition.busy_loop_s = 9999

            # sync with the writer and wait for message1
            dist.barrier()
            received = message_queue.dequeue(timeout=1)
            assert received == message1
            # in busy mode, we expect wait to have been called many times
            assert wrapped_wait.call_count > 1

            # simulate busy loop ending
            message_queue._spin_condition.busy_loop_s = 0
            # ensure we enter idle mode, then record call count
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
            idle_call_count = wrapped_wait.call_count

            # sync with the writer and wait for message2
            dist.barrier()
            received = message_queue.dequeue(timeout=1)
            assert received == message2
            # call to wait after idle should only happen once
            assert wrapped_wait.call_count == idle_call_count + 1

    message_queue.shutdown()
    assert message_queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")
def test_message_queue_busy_to_idle():
    """Run the busy-to-idle transition worker in four processes."""
    distributed_run(worker_fn_test_busy_to_idle, world_size=4)
def test_warning_logs(caplog_vllm):
    """
    Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
    when indefinite=False, and are not emitted when indefinite=True.
    """
    warning_text = "No available shared memory broadcast block found in 0 seconds"

    # Patch the warning log interval to every 1 ms during reads
    with mock.patch(
        "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
        new=0.001,  # 1 ms
    ):
        writer = MessageQueue(
            n_reader=1,
            n_local_reader=1,
            max_chunk_bytes=1024 * 1024,  # 1MB chunks
            max_chunks=10,
        )
        reader = None
        try:
            reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
            writer.wait_until_ready()
            reader.wait_until_ready()

            # We should have at least one warning log here
            # "0 seconds" expected due to rounding of 1ms test interval
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=False)
            assert any(warning_text in record.message for record in caplog_vllm.records)
            caplog_vllm.clear()

            # We should have no warnings this time
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=True)
            assert all(
                warning_text not in record.message for record in caplog_vllm.records
            )
        finally:
            # Clean up when done, including when an assertion above fails
            writer.shutdown()
            if reader is not None:
                reader.shutdown()
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
import pickle import pickle
import sys
import threading import threading
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing import shared_memory from multiprocessing import shared_memory
from pickle import PickleBuffer from pickle import PickleBuffer
from threading import Event
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch from unittest.mock import patch
...@@ -18,6 +18,7 @@ import zmq ...@@ -18,6 +18,7 @@ import zmq
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from zmq import ( # type: ignore from zmq import ( # type: ignore
IPV6, # type: ignore IPV6, # type: ignore
PUB,
SUB, SUB,
SUBSCRIBE, SUBSCRIBE,
XPUB, XPUB,
...@@ -32,6 +33,7 @@ from vllm.platforms import current_platform ...@@ -32,6 +33,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
get_ip, get_ip,
get_open_port, get_open_port,
get_open_zmq_inproc_path,
get_open_zmq_ipc_path, get_open_zmq_ipc_path,
is_valid_ipv6_address, is_valid_ipv6_address,
) )
...@@ -78,50 +80,125 @@ def to_bytes_big(value: int, size: int) -> bytes: ...@@ -78,50 +80,125 @@ def to_bytes_big(value: int, size: int) -> bytes:
logger = init_logger(__name__) logger = init_logger(__name__)
# %-style template for the "still waiting" log line; callers pass the interval
# (in seconds) as a separate logger argument so formatting happens lazily.
LONG_WAIT_TIME_LOG_MSG = (
    "No available shared memory broadcast block found in %d seconds. "
    "This typically happens when some processes are hanging or doing "
    "some time-consuming work (e.g. compilation, weight/kv cache "
    "quantization)."
)
class SpinCondition:
    """
    Wake-up primitive for shared memory buffer readers, with an interface
    similar to a threading.Condition.

    A writer calls `notify()` to tell readers to wake up and read from the
    shared memory buffer; the notification travels over a zmq socket.

    Under load the readers should not have to poll the zmq socket for every
    read, so `wait()` returns immediately while reads are frequent. Only
    after a period of inactivity does it enter "idle mode" and block on the
    socket until a notification arrives. Readers can therefore spin quickly
    when busy, hence "SpinCondition".

    For clean shutdown, another thread in the reader's process must be able
    to wake the reader so it can exit. `cancel()` provides that interruption
    through a separate in-process socket pair.
    """

    def __init__(
        self,
        is_reader: bool,
        context: zmq.Context,
        notify_address: str,
        busy_loop_s: float = 1,
    ):
        self.is_reader = is_reader
        if is_reader:
            self._init_reader(context, notify_address, busy_loop_s)
        else:
            self._init_writer(context, notify_address)

    def _init_reader(
        self, context: zmq.Context, notify_address: str, busy_loop_s: float
    ) -> None:
        """Set up the reader side: timing state, sockets and poller."""
        # Time of last shm buffer read
        self.last_read = time.monotonic()

        # Time to keep busy-looping on the shm buffer before going idle
        self.busy_loop_s = busy_loop_s

        # Readers subscribe to write notifications
        notify_socket: zmq.Socket = context.socket(SUB)
        # zmq.CONFLATE keeps only the last message the socket receives, so
        # notifications don't pile up under high load while the socket is
        # not being polled.
        notify_socket.setsockopt(zmq.CONFLATE, 1)
        # Subscribe to all messages on the socket
        notify_socket.setsockopt_string(SUBSCRIBE, "")
        notify_socket.connect(notify_address)
        self.local_notify_socket = notify_socket

        # Readers require a process-local socket pair to poll for cancellation
        cancel_path = get_open_zmq_inproc_path()
        self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
        self.write_cancel_socket.bind(cancel_path)
        self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
        self.read_cancel_socket.connect(cancel_path)

        # Poller allows waiting on either `.notify()` or `.cancel()`
        self.poller = zmq.Poller()
        self.poller.register(self.read_cancel_socket, zmq.POLLIN)
        self.poller.register(self.local_notify_socket, zmq.POLLIN)

    def _init_writer(self, context: zmq.Context, notify_address: str) -> None:
        """Set up the writer side: a PUB socket and inert reader fields."""
        # Writer side publishes write notifications
        notify_socket: zmq.Socket = context.socket(PUB)  # type: ignore
        # High water mark of 1: a massive amount of pings is not needed
        # during busy operation. PUB sockets silently drop subsequent
        # messages once the high water mark is reached.
        notify_socket.setsockopt(zmq.SNDHWM, 1)
        notify_socket.bind(notify_address)
        self.local_notify_socket = notify_socket

        # Reader-only state is left inert on the writer
        self.last_read = 0
        self.busy_loop_s = 0
        self.read_cancel_socket = None
        self.write_cancel_socket = None
        self.poller = None

    def record_read(self):
        """Stamp the current monotonic time as the moment of the last read."""
        self.last_read = time.monotonic()

    def cancel(self):
        """Send the cancellation ping that wakes up a waiting reader.

        Meant to be called from a monitor thread in the same process as the
        reader. Does nothing on the writer side.
        """
        if not self.is_reader:
            return
        logger.debug("Canceling waiting reads on SHM Buffer")
        self.write_cancel_socket.send(b"\x00")

    def wait(self, timeout_ms: int | None = None) -> None:
        """Wait for data on the shared memory buffer.

        If less than `self.busy_loop_s` seconds have passed since the last
        read, yields the scheduler and returns immediately.

        Otherwise enters idle mode and awaits a socket ping for at most
        `timeout_ms` milliseconds, or indefinitely if `timeout_ms` is None.
        """
        assert self.is_reader, "Only readers can wait"
        still_busy = time.monotonic() <= self.last_read + self.busy_loop_s
        if still_busy:
            sched_yield()
            return

        ready = dict(self.poller.poll(timeout=timeout_ms))
        if self.read_cancel_socket in ready:
            logger.debug("Poller received cancel event")
            return
        if self.local_notify_socket not in ready:
            logger.debug("Poller timed out")
            return

        logger.debug("Poller received notify event")
        # With zmq.CONFLATE set there is only one notification to read
        # from the socket
        self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False)

    def notify(self):
        """Notifies all readers to wake up"""
        assert not self.is_reader, "Only writers can notify"
        self.local_notify_socket.send(b"\x00")
class ShmRingBuffer: class ShmRingBuffer:
...@@ -265,6 +342,7 @@ class Handle: ...@@ -265,6 +342,7 @@ class Handle:
buffer_handle: tuple[int, int, int, str] | None = None buffer_handle: tuple[int, int, int, str] | None = None
local_subscribe_addr: str | None = None local_subscribe_addr: str | None = None
local_notify_addr: str | None = None
remote_subscribe_addr: str | None = None remote_subscribe_addr: str | None = None
remote_addr_ipv6: bool = False remote_addr_ipv6: bool = False
...@@ -288,7 +366,7 @@ class MessageQueue: ...@@ -288,7 +366,7 @@ class MessageQueue:
self.n_local_reader = n_local_reader self.n_local_reader = n_local_reader
n_remote_reader = n_reader - n_local_reader n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader self.n_remote_reader = n_remote_reader
self.shutting_down = False
context = Context() context = Context()
if n_local_reader > 0: if n_local_reader > 0:
...@@ -310,11 +388,19 @@ class MessageQueue: ...@@ -310,11 +388,19 @@ class MessageQueue:
self.local_socket.bind(local_subscribe_addr) self.local_socket.bind(local_subscribe_addr)
self.current_idx = 0 self.current_idx = 0
# Create the notification side of the SpinCondition
local_notify_addr = get_open_zmq_ipc_path()
self._spin_condition = SpinCondition(
is_reader=False, context=context, notify_address=local_notify_addr
)
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
local_subscribe_addr = None local_subscribe_addr = None
self.local_socket = None self.local_socket = None
self.current_idx = -1 self.current_idx = -1
local_notify_addr = None
self._spin_condition = None # type: ignore
remote_addr_ipv6 = False remote_addr_ipv6 = False
if n_remote_reader > 0: if n_remote_reader > 0:
...@@ -341,12 +427,12 @@ class MessageQueue: ...@@ -341,12 +427,12 @@ class MessageQueue:
self.local_reader_rank = -1 self.local_reader_rank = -1
# rank does not matter for remote readers # rank does not matter for remote readers
self._is_remote_reader = False self._is_remote_reader = False
self._read_spin_timer = SpinTimer()
self.handle = Handle( self.handle = Handle(
local_reader_ranks=local_reader_ranks, local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle() if self.buffer is not None else None, buffer_handle=self.buffer.handle() if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr, local_subscribe_addr=local_subscribe_addr,
local_notify_addr=local_notify_addr,
remote_subscribe_addr=remote_subscribe_addr, remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6, remote_addr_ipv6=remote_addr_ipv6,
) )
...@@ -379,9 +465,9 @@ class MessageQueue: ...@@ -379,9 +465,9 @@ class MessageQueue:
self.local_socket.connect(socket_addr) self.local_socket.connect(socket_addr)
self.remote_socket = None self.remote_socket = None
assert isinstance(handle.local_notify_addr, str)
self._read_spin_timer = ( self._spin_condition = SpinCondition(
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() is_reader=True, context=context, notify_address=handle.local_notify_addr
) )
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
...@@ -399,7 +485,9 @@ class MessageQueue: ...@@ -399,7 +485,9 @@ class MessageQueue:
socket_addr = handle.remote_subscribe_addr socket_addr = handle.remote_subscribe_addr
logger.debug("Connecting to %s", socket_addr) logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr) self.remote_socket.connect(socket_addr)
self._spin_condition = None # type: ignore
self.shutting_down = False
return self return self
def wait_until_ready(self): def wait_until_ready(self):
...@@ -435,6 +523,13 @@ class MessageQueue: ...@@ -435,6 +523,13 @@ class MessageQueue:
recv = self.remote_socket.recv() recv = self.remote_socket.recv()
assert recv == b"READY" assert recv == b"READY"
def shutdown(self):
    """Flag this queue as shutting down and, when it has a spin condition,
    cancel it so an idle reader wakes up and can clean up and exit."""
    self.shutting_down = True
    spin_condition = self._spin_condition
    if spin_condition is None:
        return
    spin_condition.cancel()
@contextmanager @contextmanager
def acquire_write(self, timeout: float | None = None): def acquire_write(self, timeout: float | None = None):
assert self._is_writer, "Only writers can acquire write" assert self._is_writer, "Only writers can acquire write"
...@@ -465,7 +560,7 @@ class MessageQueue: ...@@ -465,7 +560,7 @@ class MessageQueue:
# if we wait for a long time, log a message # if we wait for a long time, log a message
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
logger.info( logger.info(
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
) )
n_warning += 1 n_warning += 1
...@@ -503,16 +598,60 @@ class MessageQueue: ...@@ -503,16 +598,60 @@ class MessageQueue:
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break break
class ReadTimeoutWithWarnings:
def __init__(self, timeout: float | None, should_warn: bool) -> None:
self.started = time.monotonic()
self.deadline = sys.maxsize if timeout is None else self.started + timeout
# if should_warn, we need to wake up periodically to log
self.warning_wait_time_ms: int | None = (
VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None
)
self._should_warn = should_warn
self.n_warning = 1
self.timeout = timeout
def timeout_ms(self) -> int | None:
"""Returns a timeout that is:
- min(time to deadline, time to next warning) if we're logging warnings
- time to deadline, if we're not logging warnings
- None if the timeout is None and we're not logging warnings
- raise TimeoutError if we are past the deadline
"""
warning_wait_time = self.warning_wait_time_ms
if self.timeout is None:
return warning_wait_time
time_left_ms = int((self.deadline - time.monotonic()) * 1000)
if time_left_ms <= 0:
raise TimeoutError
if warning_wait_time and warning_wait_time < time_left_ms:
return warning_wait_time
return time_left_ms
def should_warn(self) -> bool:
"""Returns true if it's time to log a warning for a timeout that is not
indefinite"""
if self._should_warn:
elapsed = time.monotonic() - self.started
if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning:
self.n_warning += 1
return True
return False
@contextmanager @contextmanager
def acquire_read( def acquire_read(
self, self,
timeout: float | None = None, timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False, indefinite: bool = False,
): ):
assert self._is_local_reader, "Only readers can acquire read" assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic() read_timeout = self.ReadTimeoutWithWarnings(
n_warning = 1 timeout=timeout, should_warn=not indefinite
)
with self.buffer.get_metadata(self.current_idx) as metadata_buffer: with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
while True: while True:
# Memory fence ensures we see the latest writes from the writer. # Memory fence ensures we see the latest writes from the writer.
...@@ -529,26 +668,16 @@ class MessageQueue: ...@@ -529,26 +668,16 @@ class MessageQueue:
# for readers, `self.current_idx` is the next block to read # for readers, `self.current_idx` is the next block to read
# if this block is not ready, # if this block is not ready,
# we need to wait until it is written # we need to wait until it is written
self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms())
# Release the processor to other threads if self.shutting_down:
self._read_spin_timer.spin()
if cancel is not None and cancel.is_set():
raise RuntimeError("cancelled") raise RuntimeError("cancelled")
# if we time out, raise an exception
elapsed = time.monotonic() - start_time
if timeout is not None and elapsed > timeout:
raise TimeoutError
# if we wait for a long time, log a message # if we wait for a long time, log a message
if not indefinite and ( if read_timeout.should_warn():
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
):
logger.info( logger.info(
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
) )
n_warning += 1
continue continue
# found a block that is not read by this reader # found a block that is not read by this reader
...@@ -565,7 +694,7 @@ class MessageQueue: ...@@ -565,7 +694,7 @@ class MessageQueue:
memory_fence() memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity() self._spin_condition.record_read()
break break
def enqueue(self, obj, timeout: float | None = None): def enqueue(self, obj, timeout: float | None = None):
...@@ -608,18 +737,19 @@ class MessageQueue: ...@@ -608,18 +737,19 @@ class MessageQueue:
buf[offset:buf_offset] = to_bytes_big(buf_len, 4) buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
self._spin_condition.notify()
if self.n_remote_reader > 0: if self.n_remote_reader > 0:
self.remote_socket.send_multipart(all_buffers, copy=False) self.remote_socket.send_multipart(all_buffers, copy=False)
def dequeue( def dequeue(
self, self,
timeout: float | None = None, timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False, indefinite: bool = False,
): ):
"""Read from message queue with optional timeout (in seconds)""" """Read from message queue with optional timeout (in seconds)"""
if self._is_local_reader: if self._is_local_reader:
with self.acquire_read(timeout, cancel, indefinite) as buf: with self.acquire_read(timeout, indefinite) as buf:
overflow = buf[0] == 1 overflow = buf[0] == 1
if not overflow: if not overflow:
offset = 3 offset = 3
......
...@@ -179,7 +179,6 @@ if TYPE_CHECKING: ...@@ -179,7 +179,6 @@ if TYPE_CHECKING:
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
...@@ -1338,9 +1337,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1338,9 +1337,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(
os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")
), ),
# Reduce CPU usage when vLLM is idle. Enabling this will incur small
# latency penalty when a request eventually comes.
"VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
# Control the max chunk bytes (in MB) for the rpc message queue. # Control the max chunk bytes (in MB) for the rpc message queue.
# Object larger than this threshold will be broadcast to worker # Object larger than this threshold will be broadcast to worker
# processes via zmq. # processes via zmq.
...@@ -1751,7 +1747,6 @@ def compile_factors() -> dict[str, object]: ...@@ -1751,7 +1747,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
"VLLM_SLEEP_WHEN_IDLE",
"VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT",
......
...@@ -104,7 +104,6 @@ class MultiprocExecutor(Executor): ...@@ -104,7 +104,6 @@ class MultiprocExecutor(Executor):
# and ensure workers will be terminated. # and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown) self._finalizer = weakref.finalize(self, self.shutdown)
self.is_failed = False self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None self.failure_callback: FailureCallback | None = None
tp_size, pp_size, pcp_size = self._get_parallel_sizes() tp_size, pp_size, pcp_size = self._get_parallel_sizes()
...@@ -158,11 +157,14 @@ class MultiprocExecutor(Executor): ...@@ -158,11 +157,14 @@ class MultiprocExecutor(Executor):
global_start_rank = ( global_start_rank = (
self.local_world_size * self.parallel_config.node_rank_within_dp self.local_world_size * self.parallel_config.node_rank_within_dp
) )
# Keep track of socket file descriptors that are inherited by the
# worker when using fork, so that we can close them in subsequent
# workers
inherited_fds: list[int] = []
for local_rank in range(self.local_world_size): for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank global_rank = global_start_rank + local_rank
is_driver_worker = self._is_driver_worker(global_rank) is_driver_worker = self._is_driver_worker(global_rank)
unready_workers.append( unready_worker_handle = WorkerProc.make_worker_process(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
local_rank=local_rank, local_rank=local_rank,
rank=global_rank, rank=global_rank,
...@@ -170,7 +172,15 @@ class MultiprocExecutor(Executor): ...@@ -170,7 +172,15 @@ class MultiprocExecutor(Executor):
input_shm_handle=scheduler_output_handle, input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock, shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
) inherited_fds=inherited_fds,
)
unready_workers.append(unready_worker_handle)
if context.get_start_method() == "fork":
inherited_fds.extend(
[
unready_worker_handle.death_writer.fileno(),
unready_worker_handle.ready_pipe.fileno(),
]
) )
# Workers must be created before wait_for_ready to avoid # Workers must be created before wait_for_ready to avoid
...@@ -220,6 +230,7 @@ class MultiprocExecutor(Executor): ...@@ -220,6 +230,7 @@ class MultiprocExecutor(Executor):
for uw in unready_workers: for uw in unready_workers:
if uw.death_writer is not None: if uw.death_writer is not None:
uw.death_writer.close() uw.death_writer.close()
uw.death_writer = None
self._ensure_worker_termination([uw.proc for uw in unready_workers]) self._ensure_worker_termination([uw.proc for uw in unready_workers])
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
...@@ -255,6 +266,7 @@ class MultiprocExecutor(Executor): ...@@ -255,6 +266,7 @@ class MultiprocExecutor(Executor):
died = multiprocessing.connection.wait(sentinels) died = multiprocessing.connection.wait(sentinels)
_self = self_ref() _self = self_ref()
if not _self or getattr(_self, "shutting_down", False): if not _self or getattr(_self, "shutting_down", False):
logger.debug("MultiprocWorkerMonitor: shutdown already initiated")
return return
_self.is_failed = True _self.is_failed = True
proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0])
...@@ -354,8 +366,6 @@ class MultiprocExecutor(Executor): ...@@ -354,8 +366,6 @@ class MultiprocExecutor(Executor):
if output_rank is not None: if output_rank is not None:
response_mqs = (response_mqs[output_rank],) response_mqs = (response_mqs[output_rank],)
shutdown_event = self.shutdown_event
def get_response(): def get_response():
responses = [] responses = []
for mq in response_mqs: for mq in response_mqs:
...@@ -363,9 +373,7 @@ class MultiprocExecutor(Executor): ...@@ -363,9 +373,7 @@ class MultiprocExecutor(Executor):
None if deadline is None else (deadline - time.monotonic()) None if deadline is None else (deadline - time.monotonic())
) )
try: try:
status, result = mq.dequeue( status, result = mq.dequeue(timeout=dequeue_timeout)
timeout=dequeue_timeout, cancel=shutdown_event
)
except TimeoutError as e: except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e raise TimeoutError(f"RPC call to {method} timed out.") from e
if status != WorkerProc.ResponseStatus.SUCCESS: if status != WorkerProc.ResponseStatus.SUCCESS:
...@@ -408,20 +416,26 @@ class MultiprocExecutor(Executor): ...@@ -408,20 +416,26 @@ class MultiprocExecutor(Executor):
active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()]
# Give processes time to clean themselves up properly first # Give processes time to clean themselves up properly first
logger.debug("Worker Termination: allow workers to gracefully shutdown")
if wait_for_termination(active_procs(), 4): if wait_for_termination(active_procs(), 4):
return return
# Send SIGTERM if still running # Send SIGTERM if still running
logger.debug("Worker Termination: workers still running sending SIGTERM")
for p in active_procs(): for p in active_procs():
p.terminate() p.terminate()
if not wait_for_termination(active_procs(), 4): if not wait_for_termination(active_procs(), 4):
# Send SIGKILL if still running # Send SIGKILL if still running
logger.debug(
"Worker Termination: resorting to SIGKILL to take down workers"
)
for p in active_procs(): for p in active_procs():
p.kill() p.kill()
def shutdown(self): def shutdown(self):
"""Properly shut down the executor and its workers""" """Properly shut down the executor and its workers"""
if not getattr(self, "shutting_down", False): if not getattr(self, "shutting_down", False):
logger.debug("Triggering shutdown of workers")
self.shutting_down = True self.shutting_down = True
# Make sure all the worker processes are terminated first. # Make sure all the worker processes are terminated first.
...@@ -431,12 +445,20 @@ class MultiprocExecutor(Executor): ...@@ -431,12 +445,20 @@ class MultiprocExecutor(Executor):
if w.death_writer is not None: if w.death_writer is not None:
w.death_writer.close() w.death_writer.close()
w.death_writer = None w.death_writer = None
w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in workers]) self._ensure_worker_termination([w.proc for w in workers])
self.shutdown_event.set() for w in workers:
# Shutdown response queues
if w.worker_response_mq is not None:
w.worker_response_mq.shutdown()
w.worker_response_mq = None
if self.rpc_broadcast_mq is not None:
self.rpc_broadcast_mq.shutdown()
self.rpc_broadcast_mq = None self.rpc_broadcast_mq = None
for mq in self.response_mqs:
mq.shutdown()
self.response_mqs = []
def check_health(self) -> None: def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10) self.collective_rpc("check_health", timeout=10)
...@@ -609,24 +631,26 @@ class WorkerProc: ...@@ -609,24 +631,26 @@ class WorkerProc:
input_shm_handle, # Receive SchedulerOutput input_shm_handle, # Receive SchedulerOutput
shared_worker_lock: LockType, shared_worker_lock: LockType,
is_driver_worker: bool, is_driver_worker: bool,
inherited_fds: list[int],
) -> UnreadyWorkerProcHandle: ) -> UnreadyWorkerProcHandle:
context = get_mp_context() context = get_mp_context()
# (reader, writer) # Ready pipe to communicate readiness from child to parent
reader, writer = context.Pipe(duplex=False) ready_reader, ready_writer = context.Pipe(duplex=False)
# Death pipe to let child detect parent process exit
# Create death pipe to detect parent process exit
death_reader, death_writer = context.Pipe(duplex=False) death_reader, death_writer = context.Pipe(duplex=False)
process_kwargs = { process_kwargs = {
"vllm_config": vllm_config, "vllm_config": vllm_config,
"local_rank": local_rank, "local_rank": local_rank,
"rank": rank, "rank": rank,
"distributed_init_method": distributed_init_method, "distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle, "input_shm_handle": input_shm_handle,
"ready_pipe": (reader, writer), "ready_pipe": ready_writer,
"death_pipe": death_reader, "death_pipe": death_reader,
"shared_worker_lock": shared_worker_lock, "shared_worker_lock": shared_worker_lock,
"is_driver_worker": is_driver_worker, "is_driver_worker": is_driver_worker,
# Have the worker close parent end of this worker's pipes too
"inherited_fds": inherited_fds
+ [ready_reader.fileno(), death_writer.fileno()],
} }
# Run EngineCore busy loop in background process. # Run EngineCore busy loop in background process.
proc = context.Process( proc = context.Process(
...@@ -637,10 +661,12 @@ class WorkerProc: ...@@ -637,10 +661,12 @@ class WorkerProc:
) )
proc.start() proc.start()
writer.close() # Close child ends of pipes here in the parent
ready_writer.close()
death_reader.close()
# Keep death_writer open in parent - when parent exits, # Keep death_writer open in parent - when parent exits,
# death_reader in child will get EOFError # death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer)
@staticmethod @staticmethod
def wait_for_response_handle_ready( def wait_for_response_handle_ready(
...@@ -703,12 +729,41 @@ class WorkerProc: ...@@ -703,12 +729,41 @@ class WorkerProc:
return cast(list[WorkerProcHandle], ready_proc_handles) return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self):
    """Tear down this worker's message queues, worker and distributed state.

    Each message queue that is still set is shut down first, then the
    wrapped worker is shut down, the queue references are dropped, and
    finally ``destroy_model_parallel()`` and
    ``destroy_distributed_environment()`` are called.
    """
    # Shut the queues down before the worker itself is shut down.
    broadcast_mq = self.rpc_broadcast_mq
    if broadcast_mq is not None:
        broadcast_mq.shutdown()
    response_mq = self.worker_response_mq
    if response_mq is not None:
        response_mq.shutdown()

    self.worker.shutdown()

    # Drop the queue references now that both queues have been shut down.
    self.rpc_broadcast_mq = self.worker_response_mq = None

    destroy_model_parallel()
    destroy_distributed_environment()
def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event):
    """Watch ``death_pipe`` from a daemon thread and react to it closing.

    Does nothing when ``death_pipe`` is ``None``. Otherwise a daemon thread
    named "DeathPipeMonitor" blocks on ``death_pipe.recv()``. When that
    raises ``EOFError``, ``shutdown_requested`` is set and ``shutdown()`` is
    called on each of this worker's message queues that is not ``None``.
    Any other exception from ``recv()`` is only logged as a warning.

    Args:
        death_pipe: Connection to receive on, or ``None`` to disable
            monitoring.
        shutdown_requested: Event that is set once the pipe reports EOF.
    """
    if death_pipe is None:
        return

    # Pass queue references directly to avoid gc issues if passing self
    watched_queues = [self.rpc_broadcast_mq, self.worker_response_mq]

    def _wait_for_pipe_close(queues_to_shutdown: list[MessageQueue]):
        try:
            # This will block until parent process exits (pipe closes)
            death_pipe.recv()
        except EOFError:
            logger.info_once("Parent process exited, terminating worker queues")
            shutdown_requested.set()
            live_queues = [q for q in queues_to_shutdown if q is not None]
            for queue in live_queues:
                queue.shutdown()
        except Exception as err:
            logger.warning("Death monitoring error: %s", err)

    watcher = Thread(
        target=_wait_for_pipe_close,
        args=(watched_queues,),
        daemon=True,
        name="DeathPipeMonitor",
    )
    watcher.start()
@staticmethod @staticmethod
def worker_main(*args, **kwargs): def worker_main(*args, **kwargs):
"""Worker initialization and execution loops. """Worker initialization and execution loops.
...@@ -717,12 +772,12 @@ class WorkerProc: ...@@ -717,12 +772,12 @@ class WorkerProc:
# Signal handler used for graceful termination. # Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker # SystemExit exception is only raised once to allow this and worker
# processes to terminate without error # processes to terminate without error
shutdown_requested = False shutdown_requested = threading.Event()
def signal_handler(signum, frame): def signal_handler(signum, frame):
nonlocal shutdown_requested nonlocal shutdown_requested
if not shutdown_requested: if not shutdown_requested.is_set():
shutdown_requested = True shutdown_requested.set()
logger.debug( logger.debug(
"WorkerProc handling signal %d, raising SystemExit", signum "WorkerProc handling signal %d, raising SystemExit", signum
) )
...@@ -733,33 +788,20 @@ class WorkerProc: ...@@ -733,33 +788,20 @@ class WorkerProc:
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
worker = None worker = None
# tuple[Connection, Connection] ready_writer = kwargs.pop("ready_pipe")
reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None)
death_pipe: Connection | None = kwargs.pop("death_pipe", None)
shutdown_event = threading.Event() # Close inherited pipes from parent (incl. other worker pipes)
# Start death monitoring thread if death_pipe is provided # Explicitly passing in existing pipes and closing them makes the pipe
if death_pipe is not None: # behave when using fork. Otherwise, a hidden reference to the pipes
# exists in the child process and prevents EOF closure.
def monitor_parent_death(): for fd in kwargs.pop("inherited_fds", []):
try: try:
# This will block until parent process exits (pipe closes) os.close(fd)
death_pipe.recv()
except EOFError:
# Parent process has exited, terminate this worker
logger.info_once("Parent process exited, terminating worker")
# Send signal to self to trigger clean shutdown
shutdown_event.set()
except Exception as e: except Exception as e:
logger.warning("Death monitoring error: %s", e) logger.warning("Exception closing inherited connection: %s", e)
death_monitor = Thread(
target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor"
)
death_monitor.start()
try: try:
reader.close()
# Initialize tracer # Initialize tracer
rank = kwargs.get("rank", 0) rank = kwargs.get("rank", 0)
maybe_init_worker_tracer( maybe_init_worker_tracer(
...@@ -771,6 +813,8 @@ class WorkerProc: ...@@ -771,6 +813,8 @@ class WorkerProc:
worker = WorkerProc(*args, **kwargs) worker = WorkerProc(*args, **kwargs)
assert worker.worker_response_mq is not None assert worker.worker_response_mq is not None
worker.monitor_death_pipe(death_pipe, shutdown_requested)
# Send READY once we know everything is loaded # Send READY once we know everything is loaded
ready_writer.send( ready_writer.send(
{ {
...@@ -788,7 +832,7 @@ class WorkerProc: ...@@ -788,7 +832,7 @@ class WorkerProc:
ready_writer.close() ready_writer.close()
ready_writer = None ready_writer = None
worker.worker_busy_loop(cancel=shutdown_event) worker.worker_busy_loop()
except Exception: except Exception:
# NOTE: if an Exception arises in busy_loop, we send # NOTE: if an Exception arises in busy_loop, we send
...@@ -798,7 +842,7 @@ class WorkerProc: ...@@ -798,7 +842,7 @@ class WorkerProc:
if ready_writer is not None: if ready_writer is not None:
logger.exception("WorkerProc failed to start.") logger.exception("WorkerProc failed to start.")
elif shutdown_event.is_set(): elif shutdown_requested.is_set():
logger.info("WorkerProc shutting down.") logger.info("WorkerProc shutting down.")
else: else:
logger.exception("WorkerProc failed.") logger.exception("WorkerProc failed.")
...@@ -806,7 +850,7 @@ class WorkerProc: ...@@ -806,7 +850,7 @@ class WorkerProc:
# The parent sends a SIGTERM to all worker processes if # The parent sends a SIGTERM to all worker processes if
# any worker dies. Set this value so we don't re-throw # any worker dies. Set this value so we don't re-throw
# SystemExit() to avoid zmq exceptions in __del__. # SystemExit() to avoid zmq exceptions in __del__.
shutdown_requested = True shutdown_requested.set()
except SystemExit as e: except SystemExit as e:
# SystemExit is raised on SIGTERM or SIGINT, which usually indicates that # SystemExit is raised on SIGTERM or SIGINT, which usually indicates that
...@@ -859,12 +903,12 @@ class WorkerProc: ...@@ -859,12 +903,12 @@ class WorkerProc:
output = self.async_output_queue.get() output = self.async_output_queue.get()
self.enqueue_output(output) self.enqueue_output(output)
def worker_busy_loop(self, cancel: threading.Event | None = None): def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers""" """Main busy loop for Multiprocessing Workers"""
assert self.rpc_broadcast_mq is not None assert self.rpc_broadcast_mq is not None
while True: while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel, indefinite=True indefinite=True
) )
try: try:
if isinstance(method, str): if isinstance(method, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment