Unverified commit 15dac210, authored by Nick Hill, committed by GitHub
Browse files

[V1] AsyncLLM data parallel (#13923)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent 112b3e5b
...@@ -135,12 +135,14 @@ steps: ...@@ -135,12 +135,14 @@ steps:
- examples/offline_inference/rlhf.py - examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
commands: commands:
# test with tp=2 and external_dp=2 # test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp # test with internal dp
- python3 ../examples/offline_inference/data_parallel.py - python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s distributed/test_utils.py - pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
...@@ -514,7 +516,10 @@ steps: ...@@ -514,7 +516,10 @@ steps:
- vllm/worker/worker.py - vllm/worker/worker.py
- vllm/worker/model_runner.py - vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py - entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- vllm/v1/engine/
commands: commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
......
...@@ -28,6 +28,7 @@ Multi-node: ...@@ -28,6 +28,7 @@ Multi-node:
--master-port=13345 --master-port=13345
""" """
import os import os
from time import sleep
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import get_open_port from vllm.utils import get_open_port
...@@ -36,14 +37,13 @@ from vllm.utils import get_open_port ...@@ -36,14 +37,13 @@ from vllm.utils import get_open_port
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank): dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size) os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# set devices for each dp_rank
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
str(i) # engine processes.
for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
GPUs_per_dp_rank))
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
...@@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, ...@@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}") f"Generated text: {generated_text!r}")
# Give engines time to pause their processing loops before exiting.
sleep(1)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
...@@ -152,8 +155,13 @@ if __name__ == "__main__": ...@@ -152,8 +155,13 @@ if __name__ == "__main__":
procs.append(proc) procs.append(proc)
exit_code = 0 exit_code = 0
for proc in procs: for proc in procs:
proc.join() proc.join(timeout=300)
if proc.exitcode: if proc.exitcode is None:
print(f"Killing process {proc.pid} that "
f"didn't stop within 5 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:
exit_code = proc.exitcode exit_code = proc.exitcode
exit(exit_code) exit(exit_code)
...@@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, ...@@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
core_client: SyncMPClient = client core_client: SyncMPClient = client
result = core_client._call_utility("echo", "testarg") result = core_client.call_utility("echo", "testarg")
assert result == "testarg" assert result == "testarg"
with pytest.raises(Exception) as e_info: with pytest.raises(Exception) as e_info:
core_client._call_utility("echo", None, "help!") core_client.call_utility("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!" assert str(e_info.value) == "Call to echo method failed: help!"
...@@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): ...@@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
core_client: AsyncMPClient = client core_client: AsyncMPClient = client
result = await core_client._call_utility_async("echo", "testarg") result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg" assert result == "testarg"
with pytest.raises(Exception) as e_info: with pytest.raises(Exception) as e_info:
await core_client._call_utility_async("echo", None, "help!") await core_client.call_utility_async("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!" assert str(e_info.value) == "Call to echo method failed: help!"
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from contextlib import ExitStack
from typing import Optional
import pytest
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
# Engine configuration shared by the tests in this module. The tensor- and
# data-parallel sizes come from the TP_SIZE / DP_SIZE environment variables
# (defaulting to 1 and 2 respectively) so the same test can be run under
# different parallel layouts.
engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",
    enforce_eager=True,
    disable_log_requests=True,
    tensor_parallel_size=int(os.environ.get("TP_SIZE", 1)),
    data_parallel_size=int(os.environ.get("DP_SIZE", 2)),
)

# Skip the whole module when the current platform does not support V1 for
# this model configuration.
if not current_platform.supports_v1(engine_args.create_model_config()):
    pytest.skip(
        reason="Requires V1-supporting platform.",
        allow_module_level=True,
    )
async def generate(engine: AsyncLLM,
                   request_id: str,
                   prompt: PromptType,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
    """Submit one request to ``engine`` and count the tokens it yields.

    Sampling uses ``temperature=0`` with ``ignore_eos=True`` and the given
    ``max_tokens`` / ``prompt_logprobs``.

    Returns:
        A ``(token_count, request_id)`` tuple. For ``DELTA`` outputs the
        count is the sum of the token counts of every yielded output; for
        any other output kind it is the token count of the last output.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0,
        prompt_logprobs=prompt_logprobs,
    )

    total = 0
    stream = engine.generate(request_id=request_id,
                             prompt=prompt,
                             sampling_params=params)
    async for out in stream:
        produced = len(out.outputs[0].token_ids)
        # Delta outputs are accumulated; otherwise the latest output's token
        # count replaces the running value.
        total = (total + produced
                 if output_kind == RequestOutputKind.DELTA else produced)

        # Yield control to the event loop between outputs.
        await asyncio.sleep(0.)

    return total, request_id
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):
    """Send many concurrent requests to a data parallel AsyncLLM.

    Checks that every completed request produced the expected number of
    tokens, and that no requests or running engines are left afterwards.
    """
    NUM_REQUESTS = 100
    NUM_EXPECTED_TOKENS = 10
    prompt = "This is a test of data parallel"

    with ExitStack() as after:
        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the block exits, whether or not the
        # assertions below pass.
        after.callback(engine.shutdown)

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, f"request-{i}", prompt, output_kind,
                         NUM_EXPECTED_TOKENS))
            for i in range(NUM_REQUESTS)
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_EXCEPTION)
        for leftover in pending:
            leftover.cancel()
        for finished in done:
            # Awaiting a finished task re-raises any exception it hit.
            num_generated_tokens, request_id = await finished
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()

        # testing internals here which may break
        core_client: DPAsyncMPClient = engine.engine_core

        # the engines only synchronize stopping every N steps so
        # allow a small amount of time here (at most 10 polls, 0.5s apart).
        polls = 0
        while polls < 10 and core_client.num_engines_running != 0:
            await asyncio.sleep(0.5)
            polls += 1

        assert core_client.num_engines_running == 0
        assert not core_client.reqs_in_flight
...@@ -40,7 +40,8 @@ from vllm.transformers_utils.config import ( ...@@ -40,7 +40,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname) get_cpu_memory, get_open_port, random_uuid,
resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -1389,6 +1390,8 @@ class ParallelConfig: ...@@ -1389,6 +1390,8 @@ class ParallelConfig:
tensor_parallel_size: int = 1 # Number of tensor parallel groups. tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups. data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group. data_parallel_rank: int = 0 # Rank of the data parallel group.
# Local rank of the data parallel group, defaults to global rank.
data_parallel_rank_local: Optional[int] = None
# IP of the data parallel master. # IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1" data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master. data_parallel_master_port: int = 29500 # Port of the data parallel master.
...@@ -1493,10 +1496,18 @@ class ParallelConfig: ...@@ -1493,10 +1496,18 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \ self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size self.tensor_parallel_size
self.data_parallel_size = envs.VLLM_DP_SIZE if self.data_parallel_size > 1:
self.data_parallel_rank = envs.VLLM_DP_RANK # Data parallel was specified in the engine args.
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = get_open_port()
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT # TODO multi-node
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size self.world_size_across_dp = self.world_size * self.data_parallel_size
if self.distributed_executor_backend == "external_launcher": if self.distributed_executor_backend == "external_launcher":
......
...@@ -15,6 +15,8 @@ import torch ...@@ -15,6 +15,8 @@ import torch
from torch.distributed import ProcessGroup, TCPStore from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore, from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout, _get_default_timeout,
_shutdown_backend,
_unregister_process_group,
is_nccl_available) is_nccl_available)
from torch.distributed.rendezvous import rendezvous from torch.distributed.rendezvous import rendezvous
...@@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group( ...@@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group(
pg._register_backend(device, backend_type, backend_class) pg._register_backend(device, backend_type, backend_class)
return pg return pg
def stateless_destroy_torch_distributed_process_group(
        pg: ProcessGroup) -> None:
    """Tear down a ProcessGroup obtained from
    stateless_init_torch_distributed_process_group().

    The group's backend is shut down first; the group is then unregistered
    under its ``group_name``.
    """
    _shutdown_backend(pg)
    group_name = pg.group_name
    _unregister_process_group(group_name)
...@@ -114,6 +114,7 @@ class EngineArgs: ...@@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers # number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
data_parallel_size: int = 1
enable_expert_parallel: bool = False enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None block_size: Optional[int] = None
...@@ -442,6 +443,14 @@ class EngineArgs: ...@@ -442,6 +443,14 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.') help='Number of tensor parallel replicas.')
parser.add_argument('--data-parallel-size',
'-dp',
type=int,
default=EngineArgs.data_parallel_size,
help='Number of data parallel replicas. '
'MoE layers will be sharded according to the '
'product of the tensor-parallel-size and '
'data-parallel-size.')
parser.add_argument( parser.add_argument(
'--enable-expert-parallel', '--enable-expert-parallel',
action='store_true', action='store_true',
...@@ -1359,6 +1368,7 @@ class EngineArgs: ...@@ -1359,6 +1368,7 @@ class EngineArgs:
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers, max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce, disable_custom_all_reduce=self.disable_custom_all_reduce,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import hashlib import hashlib
import os import os
import sys
import tempfile import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
...@@ -95,6 +96,7 @@ if TYPE_CHECKING: ...@@ -95,6 +96,7 @@ if TYPE_CHECKING:
VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0 VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1 VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
...@@ -625,6 +627,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -625,6 +627,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_RANK": "VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")), lambda: int(os.getenv("VLLM_DP_RANK", "0")),
# Local rank of the process in the data parallel setting.
# Defaults to VLLM_DP_RANK when not set.
"VLLM_DP_RANK_LOCAL":
lambda: int(
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
# World size of the data parallel setting # World size of the data parallel setting
"VLLM_DP_SIZE": "VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")), lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
......
...@@ -578,7 +578,7 @@ def get_open_port() -> int: ...@@ -578,7 +578,7 @@ def get_open_port() -> int:
dp_port = envs.VLLM_DP_MASTER_PORT dp_port = envs.VLLM_DP_MASTER_PORT
while True: while True:
port = _get_open_port() port = _get_open_port()
if port >= dp_port and port < dp_port + 10: if dp_port <= port < dp_port + 10:
continue continue
return port return port
return _get_open_port() return _get_open_port()
...@@ -2176,11 +2176,11 @@ def make_zmq_socket( ...@@ -2176,11 +2176,11 @@ def make_zmq_socket(
if socket_type == zmq.constants.PULL: if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0) socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size) socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path) socket.bind(path)
elif socket_type == zmq.constants.PUSH: elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0) socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size) socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path) socket.connect(path)
else: else:
raise ValueError(f"Unknown Socket Type: {socket_type}") raise ValueError(f"Unknown Socket Type: {socket_type}")
...@@ -2188,7 +2188,11 @@ def make_zmq_socket( ...@@ -2188,7 +2188,11 @@ def make_zmq_socket(
@contextlib.contextmanager @contextlib.contextmanager
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: def zmq_socket_ctx(
path: str,
socket_type: Any,
linger: int = 0,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket""" """Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined] ctx = zmq.Context() # type: ignore[attr-defined]
...@@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: ...@@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger.debug("Got Keyboard Interrupt.") logger.debug("Got Keyboard Interrupt.")
finally: finally:
ctx.destroy(linger=0) ctx.destroy(linger=linger)
def is_in_ray_actor(): def is_in_ray_actor():
......
...@@ -37,9 +37,10 @@ class Scheduler(SchedulerInterface): ...@@ -37,9 +37,10 @@ class Scheduler(SchedulerInterface):
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
structured_output_manager: StructuredOutputManager, structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
...@@ -48,6 +49,12 @@ class Scheduler(SchedulerInterface): ...@@ -48,6 +49,12 @@ class Scheduler(SchedulerInterface):
self.log_stats = log_stats self.log_stats = log_stats
self.structured_output_manager = structured_output_manager self.structured_output_manager = structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \ self.max_num_scheduled_tokens = \
...@@ -663,10 +670,16 @@ class Scheduler(SchedulerInterface): ...@@ -663,10 +670,16 @@ class Scheduler(SchedulerInterface):
new_running.append(request) new_running.append(request)
self.running = new_running self.running = new_running
return EngineCoreOutputs( engine_core_outputs = EngineCoreOutputs(
outputs=outputs, outputs=outputs,
scheduler_stats=self.make_stats(), scheduler_stats=self.make_stats(),
) )
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
return engine_core_outputs
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
......
...@@ -128,12 +128,18 @@ class EngineCoreOutputs( ...@@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact, #NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout # e.g. columnwise layout
engine_index: int = 0
# [num_reqs] # [num_reqs]
outputs: list[EngineCoreOutput] = [] outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0 timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
def __post_init__(self): def __post_init__(self):
if self.timestamp == 0.0: if self.timestamp == 0.0:
...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum): ...@@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
""" """
ADD = b'\x00' ADD = b'\x00'
ABORT = b'\x01' ABORT = b'\x01'
UTILITY = b'\x02' START_DP = b'\x02'
UTILITY = b'\x03'
...@@ -66,11 +66,17 @@ class AsyncLLM(EngineClient): ...@@ -66,11 +66,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats: if self.log_stats:
if logger.isEnabledFor(logging.INFO): for i in range(vllm_config.parallel_config.data_parallel_size):
self.stat_loggers.append(LoggingStatLogger()) loggers: list[StatLoggerBase] = []
self.stat_loggers.append(PrometheusStatLogger(vllm_config)) if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -329,6 +335,7 @@ class AsyncLLM(EngineClient): ...@@ -329,6 +335,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
self._record_stats( self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
...@@ -350,12 +357,13 @@ class AsyncLLM(EngineClient): ...@@ -350,12 +357,13 @@ class AsyncLLM(EngineClient):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
engine_index: int = 0,
): ):
if not self.log_stats: if not self.log_stats:
return return
assert scheduler_stats is not None assert scheduler_stats is not None
for stat_logger in self.stat_loggers: for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats, stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats) iteration_stats=iteration_stats)
...@@ -393,8 +401,9 @@ class AsyncLLM(EngineClient): ...@@ -393,8 +401,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None, scheduler_outputs=None,
model_output=None, model_output=None,
) -> None: ) -> None:
for stat_logger in self.stat_loggers: for loggers in self.stat_loggers:
stat_logger.log() for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None: async def check_health(self) -> None:
logger.debug("Called check_health.") logger.debug("Called check_health.")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import queue import queue
import signal import signal
import sys
import threading import threading
import time import time
from concurrent.futures import Future from concurrent.futures import Future
from inspect import isclass, signature from inspect import isclass, signature
from multiprocessing.connection import Connection from logging import DEBUG
from typing import Any, Optional from typing import Any, Optional
import msgspec import msgspec
...@@ -14,7 +15,9 @@ import psutil ...@@ -14,7 +15,9 @@ import psutil
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm.config import VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
...@@ -91,6 +94,8 @@ class EngineCore: ...@@ -91,6 +94,8 @@ class EngineCore:
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config, lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config, speculative_config=vllm_config.speculative_config,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats, log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager, structured_output_manager=self.structured_output_manager,
) )
...@@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore): ...@@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore):
self, self,
input_path: str, input_path: str,
output_path: str, output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
engine_index: int = 0,
): ):
super().__init__(vllm_config, executor_class, log_stats) super().__init__(vllm_config, executor_class, log_stats)
...@@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore): ...@@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore):
args=(input_path, ), args=(input_path, ),
daemon=True).start() daemon=True).start()
threading.Thread(target=self.process_output_socket, threading.Thread(target=self.process_output_socket,
args=(output_path, ), args=(output_path, engine_index),
daemon=True).start() daemon=True).start()
# Send Readiness signal to EngineClient. self.global_unfinished_reqs = False
ready_pipe.send({"status": "READY"})
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod @staticmethod
def run_engine_core(*args, **kwargs): def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process.""" """Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination. # Signal handler used for graceful termination.
...@@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore): ...@@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
engine_core = None engine_core: Optional[EngineCoreProc] = None
try: try:
engine_core = EngineCoreProc(*args, **kwargs) parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop() engine_core.run_busy_loop()
except SystemExit: except SystemExit:
...@@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore): ...@@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
while not self.scheduler.has_requests(): self._process_input_queue()
logger.debug("EngineCore busy loop waiting.") # 2) Step the engine core and return the outputs.
req = self.input_queue.get() self._process_engine_step()
self._handle_client_request(*req)
def _process_input_queue(self):
# 2) Handle any new client requests. """Exits when an engine step needs to be performed."""
while not self.input_queue.empty():
req = self.input_queue.get_nowait() waited = False
self._handle_client_request(*req) while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
# 3) Step the engine core. if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
outputs = step_fn() logger.debug("EngineCore waiting for work.")
waited = True
# 4) Put EngineCoreOutputs into the output queue. req = self.input_queue.get()
if outputs is not None: self._handle_client_request(*req)
self.output_queue.put_nowait(outputs)
if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self):
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request_type: EngineCoreRequestType, def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
...@@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore): ...@@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore):
self.add_request(request) self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT: elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request) self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
...@@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore): ...@@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str): def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread.""" """Output socket IO thread."""
# Msgpack serialization encoding. # Msgpack serialization encoding.
...@@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore): ...@@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True: while True:
outputs = self.output_queue.get() outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer) encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False) socket.send(buffer, copy=False)
# Module-level outputs object with only `engine_paused` set; the data parallel
# busy loop puts it on the output queue to notify the client that it is pausing.
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of calls to _has_global_unfinished_reqs between two finish-sync
    # all-reduces with the DP peers.
    DP_SYNC_INTERVAL = 16

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Prepare the process environment, then initialize the engine.

        Prefixes stdout/stderr with the process name and pid, restricts
        ``CUDA_VISIBLE_DEVICES`` (on CUDA-alike platforms) to the devices
        of this local DP rank, creates the DP process group and finally
        delegates to the base initializer, passing the DP rank as the last
        positional argument.

        Args:
            input_path: Forwarded unchanged to the base initializer.
            output_path: Forwarded unchanged to the base initializer.
            vllm_config: Configuration; its ``parallel_config`` supplies the
                DP size, DP rank, local DP rank and tensor parallel size.
            executor_class: Forwarded unchanged to the base initializer.
            log_stats: Forwarded unchanged to the base initializer.
        """
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = vllm_config.parallel_config.tensor_parallel_size
            # Each local DP rank gets its own contiguous slice of tp_size
            # devices: [local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size).
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(i))
                for i in range(local_dp_rank * tp_size,
                               (local_dp_rank + 1) * tp_size))

        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

        # NOTE(review): `global_unfinished_reqs` is read by run_busy_loop but
        # is not assigned in this initializer - presumably the base class
        # sets it; confirm.

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group."""
        try:
            super().shutdown()
        finally:
            # Release the group even if the base shutdown raised. getattr
            # tolerates shutdown being reached before dp_group was assigned.
            if dp_group := getattr(self, "dp_group", None):
                stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()

            if local_unfinished_reqs:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                local_unfinished_reqs = (
                    self.scheduler.has_unfinished_requests())
            else:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.global_unfinished_reqs:
                # Notify client that we are pausing the loop.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether the DP group should keep running.

        Between synchronization points this returns ``True`` without
        consulting the peers; only every ``DP_SYNC_INTERVAL``-th call asks
        ``ParallelConfig.has_unfinished_dp`` for the group-wide answer.

        Args:
            local_unfinished: Whether this engine still has unfinished
                requests.
        """
        # Optimization - only perform finish-sync all-reduce every
        # DP_SYNC_INTERVAL steps.
        self.counter += 1
        if self.counter < self.DP_SYNC_INTERVAL:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
...@@ -8,10 +8,11 @@ import threading ...@@ -8,10 +8,11 @@ import threading
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Union
import zmq import zmq
import zmq.asyncio import zmq.asyncio
...@@ -60,6 +61,9 @@ class EngineCoreClient(ABC): ...@@ -60,6 +61,9 @@ class EngineCoreClient(ABC):
"is not currently supported.") "is not currently supported.")
if multiprocess_mode and asyncio_mode: if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats) return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode: if multiprocess_mode and not asyncio_mode:
...@@ -207,28 +211,74 @@ class InprocClient(EngineCoreClient): ...@@ -207,28 +211,74 @@ class InprocClient(EngineCoreClient):
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
class CoreEngine:
    """Client-side handle for one engine core; one per data parallel rank.

    Owns the PUSH socket used to send input to the engine and the handle of
    the background process the engine runs in.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        ctx: Union[zmq.Context, zmq.asyncio.Context],
        output_path: str,
        index: int = 0,
        local_dp_rank: int = 0,
    ):
        """Open the input socket and launch the engine core process.

        If launching the process raises, everything acquired so far is
        closed again before the exception propagates.
        """
        # Paths and sockets for IPC.
        input_path = get_open_zmq_ipc_path()
        self.input_socket = make_zmq_socket(ctx, input_path,
                                            zmq.constants.PUSH)

        engine_kwargs = {
            "vllm_config": vllm_config,
            "dp_rank": index,
            "local_dp_rank": local_dp_rank,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }
        try:
            # Start EngineCore in background process.
            self.proc_handle = BackgroundProcHandle(
                input_path=input_path,
                output_path=output_path,
                process_name=f"EngineCore_{index}",
                target_fn=EngineCoreProc.run_engine_core,
                process_kwargs=engine_kwargs)
        except BaseException:
            # Ensure socket is closed if process fails to start.
            self.close()
            raise

        self.num_reqs_in_flight = 0

    def send_multipart(self, msg_parts: Sequence):
        """Send ``msg_parts`` on the input socket with ``copy=False`` and
        return whatever the socket's ``send_multipart`` returns."""
        return self.input_socket.send_multipart(msg_parts, copy=False)

    def close(self):
        """Shut down the process handle and close the input socket.

        Either attribute may be absent (e.g. construction failed part-way);
        whatever is missing is skipped.
        """
        handle = getattr(self, "proc_handle", None)
        if handle:
            handle.shutdown()

        input_socket = getattr(self, "input_socket", None)
        if input_socket:
            input_socket.close(linger=0)
@dataclass @dataclass
class BackgroundResources: class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object.""" circular reference back to the client object."""
ctx: zmq.Context ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
def __call__(self): def __call__(self):
"""Clean up background resources.""" """Clean up background resources."""
if self.proc_handle is not None: for core_engine in self.core_engines:
self.proc_handle.shutdown() core_engine.close()
# ZMQ context termination can hang if the sockets # ZMQ context termination can hang if the sockets
# aren't explicitly closed first. # aren't explicitly closed first.
if self.output_socket is not None: if self.output_socket is not None:
self.output_socket.close(linger=0) self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None: if self.shutdown_path is not None:
# We must ensure that the sync output socket is # We must ensure that the sync output socket is
# closed cleanly in its own thread. # closed cleanly in its own thread.
...@@ -284,7 +334,7 @@ class MPClient(EngineCoreClient): ...@@ -284,7 +334,7 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs) self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup. # ZMQ setup.
sync_ctx = zmq.Context() sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed # This will ensure resources created so far are closed
...@@ -293,28 +343,38 @@ class MPClient(EngineCoreClient): ...@@ -293,28 +343,38 @@ class MPClient(EngineCoreClient):
self.resources = BackgroundResources(ctx=sync_ctx) self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources) self._finalizer = weakref.finalize(self, self.resources)
# Paths for IPC. # Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path() self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Start EngineCore in background process. new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
self.resources.proc_handle = BackgroundProcHandle( vllm_config, executor_class, log_stats, self.ctx, self.output_path,
input_path=input_path, index, local_dp_rank)
output_path=self.output_path,
process_name="EngineCore", # Start engine core process(es).
target_fn=EngineCoreProc.run_engine_core, self._init_core_engines(vllm_config, new_core_engine,
process_kwargs={ self.resources.core_engines)
"vllm_config": vllm_config,
"executor_class": executor_class, # Wait for engine core process(es) to start.
"log_stats": log_stats, for engine in self.resources.core_engines:
}) engine.proc_handle.wait_for_startup()
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Default case - single core engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
core_engine = new_core_engine(
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
core_engines.append(core_engine)
self.core_engine = core_engine
def shutdown(self): def shutdown(self):
self._finalizer() self._finalizer()
...@@ -370,7 +430,7 @@ class SyncMPClient(MPClient): ...@@ -370,7 +430,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread. # shutdown signal, exit thread.
break break
(frame, ) = out_socket.recv_multipart(copy=False) frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer) outputs = decoder.decode(frame.buffer)
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
...@@ -391,18 +451,15 @@ class SyncMPClient(MPClient): ...@@ -391,18 +451,15 @@ class SyncMPClient(MPClient):
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get() return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType, def _send_input(self, request_type: EngineCoreRequestType, request: Any):
request: Any) -> None:
# (RequestType, SerializedRequest) # (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request)) msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False) self.core_engine.send_multipart(msg)
def _call_utility(self, method: str, *args) -> Any: def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future() future: Future[Any] = Future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY, self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args)) (call_id, method, args))
...@@ -419,34 +476,34 @@ class SyncMPClient(MPClient): ...@@ -419,34 +476,34 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids) self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self._call_utility("profile", is_start) self.call_utility("profile", is_start)
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache") self.call_utility("reset_prefix_cache")
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request) return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id) return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]: def list_loras(self) -> set[int]:
return self._call_utility("list_loras") return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id) return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level) self.call_utility("sleep", level)
def wake_up(self) -> None: def wake_up(self) -> None:
self._call_utility("wake_up") self.call_utility("wake_up")
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping") return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch") self.call_utility("execute_dummy_batch")
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
...@@ -464,13 +521,21 @@ class AsyncMPClient(MPClient): ...@@ -464,13 +521,21 @@ class AsyncMPClient(MPClient):
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None self.queue_task: Optional[asyncio.Task] = None
async def _start_output_queue_task(self): self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
return
# Perform IO in separate task to parallelize as much as possible. # Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client. # Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue() self.outputs_queue = asyncio.Queue()
decoder = self.decoder decoder = self.decoder
utility_results = self.utility_results utility_results = self.utility_results
outputs_queue = self.outputs_queue outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path, output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL) zmq.constants.PULL)
...@@ -483,34 +548,52 @@ class AsyncMPClient(MPClient): ...@@ -483,34 +548,52 @@ class AsyncMPClient(MPClient):
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
utility_results) utility_results)
else: continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs) outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket(), self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask") name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs: async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None: self._ensure_output_queue_task()
await self._start_output_queue_task() assert self.outputs_queue is not None
assert self.outputs_queue is not None
return await self.outputs_queue.get() return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType, async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
msg = (request_type.value, self.encoder.encode(request)) self._ensure_output_queue_task()
await self.input_socket.send_multipart(msg, copy=False)
if self.outputs_queue is None: async def call_utility_async(self, method: str, *args) -> Any:
await self._start_output_queue_task() return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args) -> Any: async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64 call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future() future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
await self._send_input(EngineCoreRequestType.UTILITY, message = (EngineCoreRequestType.UTILITY.value,
(call_id, method, args)) self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future return await future
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
...@@ -524,31 +607,146 @@ class AsyncMPClient(MPClient): ...@@ -524,31 +607,146 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids) await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self._call_utility_async("profile", is_start) await self.call_utility_async("profile", is_start)
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache") await self.call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None: async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up") await self.call_utility_async("wake_up")
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping") return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None: async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch") await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool: async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request) return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool: async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id) return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]: async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras") return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool: async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id) return await self.call_utility_async("pin_lora", lora_id)
class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        """Start the engines via the base initializer and set up the
        bookkeeping used to route requests and track running engines."""
        super().__init__(vllm_config, executor_class, log_stats)

        assert len(self.core_engines) > 1

        # Control message used for triggering dp idle mode loop.
        self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                             self.encoder.encode(None))

        # Number of "engine paused" notifications still expected.
        self.num_engines_running = 0
        # Maps request id -> engine the request was sent to.
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:
        """Create one core engine per data parallel rank and record the
        list as ``self.core_engines``."""
        # Launch a core engine for each data parallel rank.
        dp_size = vllm_config.parallel_config.data_parallel_size
        for i in range(dp_size):
            # Multi-node not yet supported so local_dp_rank == dp_rank.
            core_engines.append(new_core_engine(i, i))

        self.core_engines = core_engines

    async def call_utility_async(self, method: str, *args) -> Any:
        """Invoke ``method`` on every engine concurrently."""
        # Only the result from the first engine is returned.
        return (await asyncio.gather(*[
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]))[0]

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        """Send ``request`` to the engine with the fewest requests in
        flight, starting the other engines' DP loops when needed."""
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None

        msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))

        # Record the routing decision before the first await so that output
        # handling always sees consistent bookkeeping.
        chosen_engine = self.get_core_engine_for_request()
        self.reqs_in_flight[request.request_id] = chosen_engine
        chosen_engine.num_reqs_in_flight += 1
        if self.num_engines_running >= len(self.core_engines):
            await chosen_engine.send_multipart(msg)
        else:
            # Send request to chosen engine and dp start loop
            # control message to all other engines.
            self.num_engines_running += len(self.core_engines)
            await asyncio.gather(*[
                engine.send_multipart(msg if engine is
                                      chosen_engine else self.start_dp_msg)
                for engine in self.core_engines
            ])

        self._ensure_output_queue_task()

    def get_core_engine_for_request(self) -> CoreEngine:
        """Return the engine with the fewest requests in flight (the first
        such engine on ties)."""
        return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)

    @staticmethod
    async def process_engine_outputs(self: "DPAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        """Update per-engine bookkeeping from one batch of engine outputs.

        Declared static and assigned to ``outputs_handler``; the client is
        passed in explicitly as ``self``.
        """
        if self.reqs_in_flight:
            for req_id in outputs.finished_requests or ():
                if engine := self.reqs_in_flight.pop(req_id, None):
                    engine.num_reqs_in_flight -= 1

        if outputs.engine_paused:
            assert self.num_engines_running >= 1
            self.num_engines_running -= 1
            if not self.num_engines_running and self.reqs_in_flight:
                # If there are requests in flight here, they must have
                # been sent after the engines paused. We must make
                # sure to start the other engines:
                self.num_engines_running = len(self.core_engines)
                coros = [
                    engine.send_multipart(self.start_dp_msg)
                    for engine in self.core_engines
                    if not engine.num_reqs_in_flight
                ]
                if coros:
                    await asyncio.gather(*coros)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        """Abort the given requests on whichever engines they were sent to.

        Ids that are not tracked in ``reqs_in_flight`` are ignored.
        """
        if not request_ids:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            if engine := self.reqs_in_flight.get(request_ids[0]):
                await self._abort_requests(request_ids, engine)
            return

        # Group the ids by owning engine so each engine gets one message.
        by_engine: dict[CoreEngine, list[str]] = {}
        for req_id in request_ids:
            if engine := self.reqs_in_flight.get(req_id):
                by_engine.setdefault(engine, []).append(req_id)

        if by_engine:
            # The per-engine sends are independent; issue them together so
            # that one failing send does not keep the others from going out.
            await asyncio.gather(*[
                self._abort_requests(req_ids, engine)
                for engine, req_ids in by_engine.items()
            ])

    async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
        """Send a single ABORT message for ``request_ids`` to ``engine``."""
        await engine.send_multipart((EngineCoreRequestType.ABORT.value,
                                     self.encoder.encode(request_ids)))
...@@ -8,6 +8,7 @@ from typing_extensions import TypeVar ...@@ -8,6 +8,7 @@ from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
...@@ -60,11 +61,13 @@ class LLMEngine: ...@@ -60,11 +61,13 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
# important: init dp group before init the engine_core # important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config # In the decoupled engine case this is handled in EngineCoreProc.
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
...@@ -148,7 +151,7 @@ class LLMEngine: ...@@ -148,7 +151,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests() has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled: if self.dp_group is None:
return has_unfinished return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished) return self.has_unfinished_requests_dp(has_unfinished)
...@@ -280,3 +283,7 @@ class LLMEngine: ...@@ -280,3 +283,7 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted.""" """Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id) return self.engine_core.pin_lora(lora_id)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
...@@ -235,7 +235,10 @@ class WorkerProc: ...@@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle() worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process. # Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket: # Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle, payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL) protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send_string(WorkerProc.READY_STR)
...@@ -270,11 +273,13 @@ class WorkerProc: ...@@ -270,11 +273,13 @@ class WorkerProc:
proc = context.Process(target=WorkerProc.worker_main, proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs, kwargs=process_kwargs,
daemon=True) daemon=True)
proc.start()
# Wait for startup with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
worker_response_mq_handle = WorkerProc.wait_for_startup( proc.start()
proc, ready_path)
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
worker_response_mq = MessageQueue.create_from_handle( worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0) worker_response_mq_handle, 0)
...@@ -337,23 +342,22 @@ class WorkerProc: ...@@ -337,23 +342,22 @@ class WorkerProc:
@staticmethod @staticmethod
def wait_for_startup( def wait_for_startup(
proc: BaseProcess, proc: BaseProcess,
ready_path: str, ready_socket: zmq.Socket,
) -> Optional[Handle]: ) -> Optional[Handle]:
"""Wait until the Worker is ready.""" """Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY. # Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.") logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive(): if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.") raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string() message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False) handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer) handle = pickle.loads(handle_frame.buffer)
return handle return handle
class ResponseStatus(Enum): class ResponseStatus(Enum):
SUCCESS = auto() SUCCESS = auto()
......
...@@ -31,7 +31,8 @@ class StatLoggerBase(ABC): ...@@ -31,7 +31,8 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase): class LoggingStatLogger(StatLoggerBase):
def __init__(self): def __init__(self, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic()) self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats() self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset. # Prefix cache metrics. This cannot be reset.
...@@ -78,11 +79,13 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -78,11 +79,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output. # Format and print output.
logger.info( logger.info(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
...@@ -94,7 +97,7 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -94,7 +97,7 @@ class LoggingStatLogger(StatLoggerBase):
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics() self._unregister_vllm_metrics()
# Use this flag to hide metrics that were deprecated in # Use this flag to hide metrics that were deprecated in
...@@ -102,8 +105,11 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -102,8 +105,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.show_hidden_metrics = \ self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name"] labelnames = ["model_name", "engine"]
labelvalues = [vllm_config.model_config.served_model_name] labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
......
...@@ -105,7 +105,7 @@ class BackgroundProcHandle: ...@@ -105,7 +105,7 @@ class BackgroundProcHandle:
process_kwargs: dict[Any, Any], process_kwargs: dict[Any, Any],
): ):
context = get_mp_context() context = get_mp_context()
reader, writer = context.Pipe(duplex=False) self.reader, writer = context.Pipe(duplex=False)
assert ("ready_pipe" not in process_kwargs assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs and "input_path" not in process_kwargs
...@@ -115,14 +115,17 @@ class BackgroundProcHandle: ...@@ -115,14 +115,17 @@ class BackgroundProcHandle:
process_kwargs["output_path"] = output_path process_kwargs["output_path"] = output_path
# Run busy loop in background process. # Run busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs) self.proc = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc, self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path) input_path, output_path)
self.proc.start() self.proc.start()
def wait_for_startup(self):
# Wait for startup. # Wait for startup.
if reader.recv()["status"] != "READY": if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. " raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.") "See root cause above.")
def shutdown(self): def shutdown(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment