Commit 705f6a35 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
from collections import namedtuple from typing import Any, Dict, Optional, Union
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup import torch.distributed
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, from .parallel_state import get_tp_group
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_ca_communicator,
get_tp_pynccl_communicator)
@dataclass
class GraphCaptureContext:
    """Data handed to the caller while a CUDA graph is being captured."""

    # Stream supplied by whoever constructs the context; `graph_capture`
    # passes the stream it created for the capture.
    stream: torch.cuda.Stream
@contextmanager
def graph_capture():
    """Context manager that should surround code capturing a CUDA graph.

    A fresh CUDA stream is created and made current for the duration of the
    block, so that the kernels to capture are explicitly separated from any
    kernels launched in the background on the previously current stream.

    While the block is active:

    * the tensor-parallel custom all-reduce communicator, if there is one,
      is put into its ``capture()`` context;
    * the tensor-parallel and pipeline-parallel pynccl communicators, if
      present, are switched to ``enable=True`` on the capture stream.

    Yields:
        A `GraphCaptureContext` whose only field is the capture stream.
    """
    capture_stream = torch.cuda.Stream()
    context = GraphCaptureContext(capture_stream)

    ca_comm = get_tp_ca_communicator()
    ca_context = ca_comm.capture() if ca_comm is not None else nullcontext()

    with torch.cuda.stream(capture_stream), ca_context:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the tensor
        # size is too large, it will fallback to the next available option.
        # In summary: When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        tp_pynccl_comm = get_tp_pynccl_communicator()
        pp_pynccl_comm = get_pp_pynccl_communicator()

        # `torch.cuda.current_stream()` is the capture stream at this point
        # because we are inside `torch.cuda.stream(capture_stream)`.
        tp_context = (tp_pynccl_comm.change_state(
            enable=True, stream=torch.cuda.current_stream())
                      if tp_pynccl_comm else nullcontext())
        pp_context = (pp_pynccl_comm.change_state(
            enable=True, stream=torch.cuda.current_stream())
                      if pp_pynccl_comm else nullcontext())

        with tp_context, pp_context:
            yield context
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group. """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = get_tp_ca_communicator()
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
return out
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group.""" """All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size() return get_tp_group().all_gather(input_, dim)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def tensor_model_parallel_gather(input_: torch.Tensor, def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0, dst: int = 0,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group. """Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast `input_` from rank `src` to every rank of `group`.

    Args:
        input_: Tensor that is sent on `src` and passed to
            `torch.distributed.broadcast` as the receive buffer elsewhere.
        src: Source rank; must be one of the ranks of `group`.
        group: Process group to use; falls back to the WORLD group.

    Returns:
        `input_` itself.
    """
    pg = group or torch.distributed.group.WORLD
    member_ranks = torch.distributed.get_process_group_ranks(pg)
    assert src in member_ranks, f"Invalid src rank ({src})"

    # With a single rank there is nothing to exchange, so the collective is
    # skipped entirely.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast(input_, src=src, group=pg)
    return input_
def broadcast_object_list(obj_list: List[Any], def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
src: int = 0, Any]]] = None,
group: Optional[ProcessGroup] = None): src: int = 0):
"""Broadcast the input object list.""" if not torch.distributed.is_initialized():
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = "cpu" if value.is_cpu else "cuda"
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=metadata_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
import glob
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
from vllm.logger import init_logger
logger = init_logger(__name__)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
    """ctypes layout for ``cudaIpcMemHandle_t``: a single 128-byte field."""

    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
@dataclass
class Function:
    """Signature of one C function to bind through ctypes."""

    # Symbol name looked up on the loaded library.
    name: str
    # Value assigned to the ctypes function's `restype`.
    restype: Any
    # Value assigned to the ctypes function's `argtypes`.
    argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str:
    """Locate a ``libcudart.so.*`` file under an ``nvidia`` package directory.

    Every entry of ``sys.path`` is scanned in order for
    ``<entry>/nvidia/cuda_runtime/lib/libcudart.so.*[0-9]``; the first match
    reported by ``glob`` in the first entry that has any match is returned.

    Raises:
        ValueError: if no entry of ``sys.path`` contains a matching file.
    """
    # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
    lib_folder = "cuda_runtime"
    lib_name = "libcudart.so.*[0-9]"
    for entry in sys.path:
        nvidia_dir = os.path.join(entry, "nvidia")
        if not os.path.exists(nvidia_dir):
            continue
        matches = glob.glob(
            os.path.join(nvidia_dir, lib_folder, "lib", lib_name))
        if matches:
            return matches[0]
    raise ValueError(f"{lib_name} not found in the system path {sys.path}")
class CudaRTLibrary:
    """Thin ctypes binding over a handful of CUDA runtime (cudart) functions.

    Each wrapper method calls the C function of the same name. For functions
    returning ``cudaError_t`` a non-zero code is turned into a
    ``RuntimeError`` by `CUDART_CHECK`. Loaded libraries and their bound
    function tables are cached on the class, keyed by shared-object path.
    """

    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),

        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),

        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function("cudaMalloc", cudaError_t,
                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function("cudaMemset", cudaError_t,
                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function("cudaMemcpy", cudaError_t, [
            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
        ]),

        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function("cudaIpcGetMemHandle", cudaError_t,
                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function("cudaIpcOpenMemHandle", cudaError_t, [
            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
        ]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load (or reuse) the library at `so_file` and bind its functions.

        When `so_file` is None the path is obtained from
        `get_pytorch_default_cudart_library_path()`.
        """
        if so_file is None:
            so_file = get_pytorch_default_cudart_library_path()

        libraries = CudaRTLibrary.path_to_library_cache
        if so_file not in libraries:
            libraries[so_file] = ctypes.CDLL(so_file)
        self.lib = libraries[so_file]

        tables = CudaRTLibrary.path_to_dict_mapping
        if so_file not in tables:
            bound: Dict[str, Any] = {}
            for spec in CudaRTLibrary.exported_functions:
                c_func = getattr(self.lib, spec.name)
                c_func.restype = spec.restype
                c_func.argtypes = spec.argtypes
                bound[spec.name] = c_func
            tables[so_file] = bound
        self.funcs = tables[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise ``RuntimeError`` if `result` is a non-zero error code."""
        if result == 0:
            return
        error_str = self.cudaGetErrorString(result)
        raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        """Return the library's message for `error`, decoded as UTF-8."""
        raw_message = self.funcs["cudaGetErrorString"](error)
        return raw_message.decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        """Call ``cudaSetDevice(device)`` and check the result."""
        status = self.funcs["cudaSetDevice"](device)
        self.CUDART_CHECK(status)

    def cudaDeviceSynchronize(self) -> None:
        """Call ``cudaDeviceSynchronize()`` and check the result."""
        status = self.funcs["cudaDeviceSynchronize"]()
        self.CUDART_CHECK(status)

    def cudaDeviceReset(self) -> None:
        """Call ``cudaDeviceReset()`` and check the result."""
        status = self.funcs["cudaDeviceReset"]()
        self.CUDART_CHECK(status)

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Call ``cudaMalloc`` for `size` bytes and return the pointer."""
        dev_ptr = ctypes.c_void_p()
        status = self.funcs["cudaMalloc"](ctypes.byref(dev_ptr), size)
        self.CUDART_CHECK(status)
        return dev_ptr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        """Call ``cudaFree(devPtr)`` and check the result."""
        status = self.funcs["cudaFree"](devPtr)
        self.CUDART_CHECK(status)

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
                   count: int) -> None:
        """Call ``cudaMemset(devPtr, value, count)`` and check the result."""
        status = self.funcs["cudaMemset"](devPtr, value, count)
        self.CUDART_CHECK(status)

    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
                   count: int) -> None:
        """Call ``cudaMemcpy`` for `count` bytes from `src` to `dst`.

        The copy kind is always passed as ``cudaMemcpyDefault`` (4).
        """
        cudaMemcpyDefault = 4
        status = self.funcs["cudaMemcpy"](dst, src, count, cudaMemcpyDefault)
        self.CUDART_CHECK(status)

    def cudaIpcGetMemHandle(self,
                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Call ``cudaIpcGetMemHandle`` for `devPtr` and return the handle."""
        ipc_handle = cudaIpcMemHandle_t()
        status = self.funcs["cudaIpcGetMemHandle"](ctypes.byref(ipc_handle),
                                                   devPtr)
        self.CUDART_CHECK(status)
        return ipc_handle

    def cudaIpcOpenMemHandle(self,
                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Call ``cudaIpcOpenMemHandle`` for `handle` and return the pointer.

        The flags argument is always ``cudaIpcMemLazyEnablePeerAccess`` (1).
        """
        cudaIpcMemLazyEnablePeerAccess = 1
        dev_ptr = ctypes.c_void_p()
        status = self.funcs["cudaIpcOpenMemHandle"](
            ctypes.byref(dev_ptr), handle, cudaIpcMemLazyEnablePeerAccess)
        self.CUDART_CHECK(status)
        return dev_ptr
...@@ -9,80 +9,20 @@ import vllm.envs as envs ...@@ -9,80 +9,20 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check) gpu_p2p_access_check)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import in_the_same_node_as
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import cuda_device_count_stateless, is_full_nvlink
try: try:
if (not is_hip()): assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
import pynvml custom_ar = True
except Exception:
# Simulate ImportError if custom_ar ops are not supported. # For AMD GPUs and CPUs
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
raise ImportError("custom_ar", __file__)
custom_ar = True
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
else:
custom_ar = False
pynvml = None
@contextmanager
def _nvml():
try:
yield
finally:
pass
except ImportError:
# For AMD GPUs
custom_ar = False custom_ar = False
pynvml = None
@contextmanager
def _nvml():
try:
yield
finally:
pass
logger = init_logger(__name__) logger = init_logger(__name__)
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
so it works on real physical device ids.
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool: def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size): for i in range(world_size):
if i == rank: if i == rank:
...@@ -98,8 +38,8 @@ class CustomAllreduce: ...@@ -98,8 +38,8 @@ class CustomAllreduce:
# max_size: max supported allreduce size # max_size: max supported allreduce size
def __init__(self, def __init__(self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 1024) -> None:
""" """
Args: Args:
...@@ -119,13 +59,12 @@ class CustomAllreduce: ...@@ -119,13 +59,12 @@ class CustomAllreduce:
# e.g. in a non-cuda environment # e.g. in a non-cuda environment
return return
group = group or get_tensor_model_parallel_cpu_group()
self.group = group self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.") "CustomAllreduce should be attached to a non-NCCL group.")
if not is_in_the_same_node(group): if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case. # No need to initialize custom allreduce for multi-node case.
logger.warning( logger.warning(
"Custom allreduce is disabled because this process group" "Custom allreduce is disabled because this process group"
...@@ -146,10 +85,7 @@ class CustomAllreduce: ...@@ -146,10 +85,7 @@ class CustomAllreduce:
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return return
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
...@@ -161,7 +97,7 @@ class CustomAllreduce: ...@@ -161,7 +97,7 @@ class CustomAllreduce:
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
else: else:
device_ids = list(range(torch.cuda.device_count())) device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index] physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], tensor = torch.tensor([physical_device_id],
...@@ -177,7 +113,7 @@ class CustomAllreduce: ...@@ -177,7 +113,7 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases # test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported # where custom allreduce is not supported
# this checks hardware and driver support for NVLink # this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(physical_device_ids) full_nvlink = is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink: if world_size > 2 and not full_nvlink:
logger.warning( logger.warning(
"Custom allreduce is disabled because it's not supported on" "Custom allreduce is disabled because it's not supported on"
......
import ctypes
import json import json
import os import os
import pickle
import subprocess
import sys import sys
import tempfile from itertools import product
import time from typing import Dict, List, Optional, Sequence
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional
import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
logger = init_logger(__name__) logger = init_logger(__name__)
@contextmanager def producer(batch_src: Sequence[int],
def mute_output(): producer_queue,
with open(os.devnull, "w") as f: consumer_queue,
sys.stderr = f result_queue,
sys.stdout = f
yield
def producer(i: int,
init_method: str,
cuda_visible_devices: Optional[str] = None): cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None: if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices update_environment_variables(
with mute_output(): {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
dist.init_process_group(
backend="gloo", lib = CudaRTLibrary()
init_method=init_method, for i in batch_src:
world_size=2, lib.cudaSetDevice(i)
rank=0, pointer = lib.cudaMalloc(1024)
) lib.cudaMemset(pointer, 1, 1024)
# produce a tensor in GPU i lib.cudaDeviceSynchronize()
data = torch.zeros((128, ), device=f"cuda:{i}") handle = lib.cudaIpcGetMemHandle(pointer)
# get the information to reconstruct the shared tensor producer_queue.put(handle)
func, args = torch.multiprocessing.reductions.reduce_tensor(data) open_success = consumer_queue.get()
args = list(args) if open_success:
dist.broadcast_object_list([(func, args)], src=0) # use two queues to simulate barrier
dist.barrier() producer_queue.put(0)
torch.cuda.synchronize() consumer_queue.get()
assert torch.all(data == 1).item() # check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
def consumer(j: int, for i in range(1024):
init_method: str, if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def consumer(batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None): cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None: if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices update_environment_variables(
with mute_output(): {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
dist.init_process_group(
backend="gloo", lib = CudaRTLibrary()
init_method=init_method, for j in batch_tgt:
world_size=2, lib.cudaSetDevice(j)
rank=1, handle = producer_queue.get()
) open_success = False
torch.cuda.set_device(j) try:
recv = [None] pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
dist.broadcast_object_list(recv, src=0) open_success = True
func: Callable except RuntimeError:
args: List # cannot error out here, because the producer process
func, args = recv[0] # type: ignore # is still waiting for the response.
# `args[6]` is the device id pass
# by default pytorch will use `i` from the producer consumer_queue.put(open_success)
# here we need to set it to `j` to test P2P access if open_success:
args[6] = j # modify the memory
data = func(*args) lib.cudaMemset(pointer, 2, 1024)
data += 1 lib.cudaDeviceSynchronize()
dist.barrier() # use two queues to simulate barrier
torch.cuda.synchronize() producer_queue.get()
assert torch.all(data == 1).item() consumer_queue.put(0)
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
def can_actually_p2p(i, j): lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def can_actually_p2p(
batch_src: Sequence[int],
batch_tgt: Sequence[int],
) -> Sequence[bool]:
""" """
Usually, checking if P2P access is enabled can be done by Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(i, j)`. However, sometimes `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)` the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
returns `True` even if P2P access is not actually possible. returns `True` even if P2P access is not actually possible.
See https://github.com/vllm-project/vllm/issues/2728 and See https://github.com/vllm-project/vllm/issues/2728 and
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
...@@ -90,41 +108,55 @@ def can_actually_p2p(i, j): ...@@ -90,41 +108,55 @@ def can_actually_p2p(i, j):
Note on p2p and cuda IPC: Note on p2p and cuda IPC:
Usually, one process uses one GPU: Usually, one process uses one GPU:
GPU i --> cuda context i --> tensor i --> process i GPU src --> cuda context src --> tensor src --> process src
We need to combine p2p and cuda IPC, so that: We need to combine p2p and cuda IPC, so that:
GPU i --> cuda context i --> tensor i --> process i GPU src --> cuda context src --> tensor src --> process src
|shared| |shared|
GPU j --> cuda context j --> tensor j --> process j GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
That is to say, process i creates a tensor in GPU i, passes IPC handle to That is to say, process src creates a tensor in GPU src, passes IPC handle to
process j, and process j accesses the tensor in GPU j. Any operation on the process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
tensor in process j will be reflected in the tensor in process i, because tensor in process tgt will be reflected in the tensor in process src, because
they are the same memory segment. they are the same memory segment.
It is important to note that process j accesses the tensor in GPU j, not It is important to note that process tgt accesses the tensor in GPU tgt, not
GPU i. That's why we need p2p access. # noqa GPU src. That's why we need p2p access.
"""
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) The most time-consuming part is the process creation. To avoid creating
processes for every pair of GPUs, we use batched testing. We create two
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process # pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs # to make sure they see the same set of GPUs
# make sure the temp file is not the same across different calls
temp_path = tempfile.mktemp() + str(time.time())
# create an empty file
with open(temp_path, "w"):
pass
init_method = f"file://{temp_path}"
# make sure the processes are spawned # make sure the processes are spawned
smp = mp.get_context("spawn") smp = mp.get_context("spawn")
pi = smp.Process(target=producer, producer_queue = smp.Queue()
args=(i, init_method, cuda_visible_devices)) consumer_queue = smp.Queue()
pj = smp.Process(target=consumer, result_queue = smp.Queue()
args=(j, init_method, cuda_visible_devices)) p_src = smp.Process(target=producer,
pi.start() args=(batch_src, producer_queue, consumer_queue,
pj.start() result_queue, cuda_visible_devices))
pi.join() p_tgt = smp.Process(target=consumer,
pj.join() args=(batch_tgt, producer_queue, consumer_queue,
return pi.exitcode == 0 and pj.exitcode == 0 result_queue, cuda_visible_devices))
p_src.start()
p_tgt.start()
p_src.join()
p_tgt.join()
result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
if a != b:
logger.warning(
"Two processes do not agree on the P2P access"
" status on %d -> %d, treat as disabled.", src, tgt)
result.append(False)
else:
result.append(a)
return result
# why do we need this cache? # why do we need this cache?
...@@ -142,18 +174,18 @@ def can_actually_p2p(i, j): ...@@ -142,18 +174,18 @@ def can_actually_p2p(i, j):
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool: def gpu_p2p_access_check(src: int, tgt: int) -> bool:
"""Check if GPU i can access GPU j.""" """Check if GPU src can access GPU tgt."""
# if the cache variable is already calculated, # if the cache variable is already calculated,
# read from the cache instead of checking it again # read from the cache instead of checking it again
global _gpu_p2p_access_cache global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None: if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{i}->{j}"] return _gpu_p2p_access_cache[f"{src}->{tgt}"]
is_distributed = dist.is_initialized() is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count() num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None: if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
...@@ -162,25 +194,51 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -162,25 +194,51 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
) )
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
if ((not is_distributed or get_local_rank() == 0) from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
and (not os.path.exists(path))): and (not os.path.exists(path))):
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path) logger.info("generating GPU P2P access cache in %s", path)
cache = {} cache: Dict[str, bool] = {}
for _i in range(num_dev): ids = list(range(num_dev))
for _j in range(num_dev): # batch of all pairs of GPUs
cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j) batch_src, batch_tgt = zip(*list(product(ids, ids)))
# NOTE: we use `subprocess` rather than `multiprocessing` here
# because the caller might not have `if __name__ == "__main__":`,
# in that case we cannot use spawn method in multiprocessing.
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
input_bytes = pickle.dumps((batch_src, batch_tgt))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}") from e
result = pickle.loads(returned.stdout)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f: with open(path, "w") as f:
json.dump(cache, f, indent=4) json.dump(cache, f, indent=4)
if is_distributed: if is_distributed:
cpu_world_group = get_cpu_world_group() get_world_group().barrier()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path) logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f: with open(path, "r") as f:
cache = json.load(f) cache = json.load(f)
_gpu_p2p_access_cache = cache _gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"] return _gpu_p2p_access_cache[f"{src}->{tgt}"]
__all__ = ["gpu_p2p_access_check"] __all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
    # Script entry point used by `gpu_p2p_access_check`, which runs this file
    # through `subprocess.run` and talks to it over pickled stdin/stdout.
    # The payload is produced by the parent process of this same module, so
    # unpickling it here does not involve external data.
    request_bytes = sys.stdin.buffer.read()
    batch_src, batch_tgt = pickle.loads(request_bytes)
    result = can_actually_p2p(batch_src, batch_tgt)
    # hand the per-pair results back to the parent as a pickled object
    response_bytes = pickle.dumps(result)
    sys.stdout.buffer.write(response_bytes)
...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp ...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId) ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -19,8 +18,8 @@ class PyNcclCommunicator: ...@@ -19,8 +18,8 @@ class PyNcclCommunicator:
def __init__( def __init__(
self, self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
library_path: Optional[str] = None, library_path: Optional[str] = None,
): ):
""" """
...@@ -35,7 +34,6 @@ class PyNcclCommunicator: ...@@ -35,7 +34,6 @@ class PyNcclCommunicator:
is bind to a unique device. is bind to a unique device.
""" """
assert dist.is_initialized() assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.") "PyNcclCommunicator should be attached to a non-NCCL group.")
self.group = group self.group = group
...@@ -77,10 +75,7 @@ class PyNcclCommunicator: ...@@ -77,10 +75,7 @@ class PyNcclCommunicator:
byte_list = tensor.tolist() byte_list = tensor.tolist()
for i, byte in enumerate(byte_list): for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte self.unique_id.internal[i] = byte
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
...@@ -126,10 +121,7 @@ class PyNcclCommunicator: ...@@ -126,10 +121,7 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum.from_torch(op), self.comm, ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream)) cudaStream_t(stream.cuda_stream))
def send(self, def send(self, tensor: torch.Tensor, dst: int, stream=None):
tensor: torch.Tensor,
dst: Optional[int] = None,
stream=None):
if self.disabled: if self.disabled:
return return
assert tensor.device == self.device, ( assert tensor.device == self.device, (
...@@ -137,16 +129,11 @@ class PyNcclCommunicator: ...@@ -137,16 +129,11 @@ class PyNcclCommunicator:
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = self.stream
if dst is None:
dst = (self.rank + 1) % self.world_size
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst, ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream)) self.comm, cudaStream_t(stream.cuda_stream))
def recv(self, def recv(self, tensor: torch.Tensor, src: int, stream=None):
tensor: torch.Tensor,
src: Optional[int] = None,
stream=None):
if self.disabled: if self.disabled:
return return
assert tensor.device == self.device, ( assert tensor.device == self.device, (
...@@ -154,8 +141,6 @@ class PyNcclCommunicator: ...@@ -154,8 +141,6 @@ class PyNcclCommunicator:
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = self.stream
if src is None:
src = (self.rank - 1) % self.world_size
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src, ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream)) self.comm, cudaStream_t(stream.cuda_stream))
......
...@@ -205,7 +205,7 @@ class NCCLLibrary: ...@@ -205,7 +205,7 @@ class NCCLLibrary:
raise e raise e
if so_file not in NCCLLibrary.path_to_dict_mapping: if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs = {} _funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions: for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name) f = getattr(self.lib, func.name)
f.restype = func.restype f.restype = func.restype
......
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from unittest.mock import patch
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port
# Interval (in seconds, it is compared against `time.monotonic()` deltas)
# after which `acquire_write` / `acquire_read` log a warning while they are
# still waiting for a block; the value comes from `vllm.envs`.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL = 1e-7

# module-level logger, named after this module
logger = init_logger(__name__)
class ShmRingBuffer:
    """Fixed-size ring of chunks plus per-chunk flag bytes in shared memory.

    The object only manages the memory layout and hands out views into it;
    the read/write protocol on top of the flags is implemented by the user
    of this class (see the layout description in `__init__`).
    """

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
        the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the
        block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3
        and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.

        Args:
            n_reader: number of reader flags kept for every chunk.
            max_chunk_bytes: size of one data chunk in bytes.
            max_chunks: number of chunks in the ring.
            name: name of an existing shared memory block to attach to, or
                None to create a new block.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            # Some platforms allocate shared memory in multiples of the page
            # size, so an attached block may be larger than what the creator
            # requested. Only a block that is too small is an error.
            assert self.shared_memory.size >= self.total_bytes_of_buffer, (
                f"shared memory block {name!r} has "
                f"{self.shared_memory.size} bytes, expected at least "
                f"{self.total_bytes_of_buffer}")

    def __reduce__(self):
        """Pickle as "attach to the block with this name" so that the
        unpickling process opens the same shared memory instead of
        creating a new one."""
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        # `__del__` also runs for objects whose `__init__` raised before
        # `shared_memory` was assigned; there is nothing to release then.
        if not hasattr(self, "shared_memory"):
            return
        self.shared_memory.close()
        if self.is_creator:
            # only the creating process removes the block itself
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over the data bytes of chunk `current_idx`;
        the view is released when the context exits."""
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the `1 + n_reader` flag bytes of chunk
        `current_idx`; the view is released when the context exits."""
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
@dataclass
class Handle:
    """Connection information exported by the writer `MessageQueue`.

    Readers rebuild their side of the queue from it via
    `MessageQueue.create_from_handle`.
    """
    # address readers put into their `tcp://{connect_ip}:{port}` connect calls
    connect_ip: str
    # ranks that read through shared memory; any other rank is treated as a
    # remote reader by `create_from_handle`
    local_reader_ranks: List[int] = field(default_factory=list)

    # ring buffer shared with local readers; None when there are none
    buffer: Optional[ShmRingBuffer] = None
    # ports of the writer's local PUB and REP sockets (None without local
    # readers)
    local_subscribe_port: Optional[int] = None
    local_sync_port: Optional[int] = None
    # ports of the writer's remote PUB and REP sockets (None without remote
    # readers)
    remote_subscribe_port: Optional[int] = None
    remote_sync_port: Optional[int] = None
class MessageQueue:
    """Broadcast queue with one writer and `n_reader` readers.

    Local readers receive objects through a `ShmRingBuffer`; an object whose
    pickle does not fit into one chunk is sent to them over a zmq PUB/SUB
    socket instead. Remote readers receive every object over a separate zmq
    PUB/SUB socket. The constructor builds the writer side; readers are built
    with `create_from_handle`.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers.
            n_local_reader: number of readers served through shared memory;
                the remaining `n_reader - n_local_reader` are remote readers.
            local_reader_ranks: ranks of the local readers; defaults to
                `list(range(n_local_reader))`.
            max_chunk_bytes: chunk size passed to the `ShmRingBuffer`.
            max_chunks: number of chunks passed to the `ShmRingBuffer`.
            connect_ip: address stored in the exported handle for readers to
                connect to; defaults to `get_ip()`.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip()

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            self.local_socket = context.socket(PUB)
            local_subscribe_port = get_open_port()
            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")

            # REP socket used only for the handshake in `wait_until_ready`
            self.local_sync_socket = context.socket(REP)
            local_sync_port = get_open_port()
            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
            # index of the next chunk to write
            self.current_idx = 0

        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            local_sync_port = None
            self.local_socket = None
            self.local_sync_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = context.socket(PUB)
            remote_subscribe_port = get_open_port()
            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")

            # REP socket used only for the handshake in `wait_until_ready`
            self.remote_sync_socket = context.socket(REP)
            remote_sync_port = get_open_port()
            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
        else:
            remote_subscribe_port = None
            remote_sync_port = None
            self.remote_socket = None
            self.remote_sync_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        # everything a reader needs to attach to this queue
        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer=self.buffer,
            local_subscribe_port=local_subscribe_port,
            local_sync_port=local_sync_port,
            remote_subscribe_port=remote_subscribe_port,
            remote_sync_port=remote_sync_port,
        )

    def export_handle(self) -> Handle:
        """Return the `Handle` built in the constructor."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Build the reader side of a queue from the writer's handle.

        A `rank` listed in `handle.local_reader_ranks` becomes a local reader
        (shared memory plus the local sockets); any other rank becomes a
        remote reader that only uses the remote sockets.
        """
        # bypass `__init__`, which sets up the writer side
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer is not None
            self.buffer = handle.buffer
            # index of the next chunk to read
            self.current_idx = 0
            # position in the list selects this reader's flag byte
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            # empty prefix: subscribe to every message
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            self.local_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")

            self.local_sync_socket = context.socket(REQ)
            self.local_sync_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")

            self.remote_socket = None
            self.remote_sync_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None
            self.local_sync_socket = None

            self.remote_socket = context.socket(SUB)
            # empty prefix: subscribe to every message
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            self.remote_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

            self.remote_sync_socket = context.socket(REQ)
            self.remote_sync_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            # one request/reply exchange per reader on the sync socket ...
            for i in range(self.n_local_reader):
                recv = self.local_sync_socket.recv()
                assert recv == b"READY"
                self.local_sync_socket.send(b"READY")
            # ... followed by a single message on the publish socket
            if self.n_local_reader > 0:
                self.local_socket.send(b"READY")

            # remote readers
            for i in range(self.n_remote_reader):
                recv = self.remote_sync_socket.recv()
                assert recv == b"READY"
                self.remote_sync_socket.send(b"READY")
            if self.n_remote_reader > 0:
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # handshake on the sync socket, then wait for the published READY
            self.local_sync_socket.send(b"READY")
            recv = self.local_sync_socket.recv()
            assert recv == b"READY"
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # handshake on the sync socket, then wait for the published READY
            self.remote_sync_socket.send(b"READY")
            recv = self.remote_sync_socket.recv()
            assert recv == b"READY"
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self):
        """Wait until the chunk at `self.current_idx` may be written, yield
        its data view, then publish the chunk and advance to the next index.

        Polls with `RINGBUFFER_SLEEP_INTERVAL` sleeps and logs a warning each
        time the wait crosses another multiple of
        `VLLM_RINGBUFFER_WARNING_INTERVAL`.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if time.monotonic(
                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                        logger.warning(
                            "No available block found in %s second. ",
                            VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self):
        """Wait until the chunk at `self.current_idx` is written and not yet
        read by this reader, yield its data view, then set this reader's
        flag and advance to the next index.

        Polls with `RINGBUFFER_SLEEP_INTERVAL` sleeps and logs a warning each
        time the wait crosses another multiple of
        `VLLM_RINGBUFFER_WARNING_INTERVAL`.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                # flag byte 0 is the written flag, so reader k uses byte k + 1
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if time.monotonic(
                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                        logger.warning(
                            "No available block found in %s second. ",
                            VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj):
        """Pickle `obj` and send it to all readers (writer only).

        For local readers the first byte of the chunk is an overflow marker:
        0 means the pickle follows in the same chunk, 1 means it was sent
        over the local publish socket instead.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            # one byte of every chunk is taken by the overflow marker, so a
            # pickle of `max_chunk_bytes` or more does not fit
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write() as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write() as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self):
        """Receive and unpickle the next object (readers only).

        Raises:
            RuntimeError: if called on an object that is neither a local nor
                a remote reader.
        """
        if self._is_local_reader:
            overflow = False
            with self.acquire_read() as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                # the payload was sent over the local publish socket
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    def broadcast_object(self, obj=None):
        """On the writer, enqueue `obj` and return it; on a reader, ignore
        the argument and return the dequeued object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        """Create the queue on every rank of `pg` and wait until it is ready.

        The rank equal to `writer_rank` builds the writer and broadcasts its
        handle through `pg`; every other rank builds a reader from the
        received handle. All ranks then call `wait_until_ready`.
        """
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)

        from vllm.distributed.parallel_state import in_the_same_node_as
        # NOTE(review): presumably one truthy/falsy entry per group rank,
        # telling whether it shares a node with `writer_rank` - confirm
        # against `in_the_same_node_as`.
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        # every rank except the writer is a reader
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            # `src` is looked up in `global_ranks`, not passed as a group rank
            dist.broadcast_object_list([handle],
                                       src=global_ranks[writer_rank],
                                       group=pg)
        else:
            recv = [None]
            dist.broadcast_object_list(recv,
                                       src=global_ranks[writer_rank],
                                       group=pg)
            handle = recv[0]  # type: ignore
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io
...@@ -2,81 +2,795 @@ ...@@ -2,81 +2,795 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups.""" """vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.
If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
import contextlib import contextlib
from multiprocessing import resource_tracker, shared_memory import pickle
from typing import List, Optional from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch import torch
from torch.distributed import ProcessGroup import torch.distributed
from torch.distributed import Backend, ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True @dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
_TP_CPU_GROUP: Optional[ProcessGroup] = None
_TP_PYNCCL_COMMUNICATOR = None
_TP_CA_COMMUNICATOR = None
# Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
_PP_CPU_GROUP: Optional[ProcessGroup] = None
_PP_PYNCCL_COMMUNICATOR = None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None
# during `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None
# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one for device-specific (and is the default)
# and one for CPU. All processes will be part of both groups.
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PP_GLOBAL_RANKS: Optional[List[int]] = None
_LOCAL_RANK = -1
# Stand-in that `_split_tensor_dict` stores in its metadata list in place of
# each tensor: the device type string, the dtype and the size of the tensor.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries.")
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(prefix + key, TensorMetadata(device, value.dtype,
value.size())))
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%")
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((prefix + key, value))
return metadata_list, tensor_list
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
def get_pp_pynccl_communicator():
global _PP_PYNCCL_COMMUNICATOR
return _PP_PYNCCL_COMMUNICATOR
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
"""
def get_tp_pynccl_communicator(): # available attributes:
global _TP_PYNCCL_COMMUNICATOR rank: int # global rank
return _TP_PYNCCL_COMMUNICATOR ranks: List[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
    self,
    group_ranks: List[List[int]],
    local_rank: int,
    torch_distributed_backend: Union[str, Backend],
    use_pynccl: bool,
    use_custom_allreduce: bool,
    use_message_queue_broadcaster: bool = False,
):
    """Create this rank's process groups and optional communicators.

    Args:
        group_ranks: one list of global ranks per group to create. The
            calling process joins the group whose list contains its
            global rank.
        local_rank: stored as `self.local_rank`; selects the
            `cuda:{local_rank}` device when CUDA is available.
        torch_distributed_backend: backend of the device process groups.
        use_pynccl: create a `PyNcclCommunicator` (world size > 1 only).
        use_custom_allreduce: create a `CustomAllreduce` communicator
            (world size > 1 only).
        use_message_queue_broadcaster: create a shared-memory
            `MessageQueue` broadcaster (world size > 1 only).

    Raises:
        AssertionError: if the calling rank is in none of `group_ranks`.
    """
    self.rank = torch.distributed.get_rank()
    self.local_rank = local_rank
    self.device_group = None
    self.cpu_group = None

    # Every process creates every group, including the ones it does not
    # belong to; only the group containing this rank is kept.
    for ranks in group_ranks:
        device_group = torch.distributed.new_group(
            ranks, backend=torch_distributed_backend)
        # a group with `gloo` backend, to allow direct coordination between
        # processes through the CPU.
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            self.device_group = device_group
            self.cpu_group = cpu_group

    assert self.cpu_group is not None, (
        f"rank {self.rank} is not a member of any group in group_ranks")
    assert self.device_group is not None, (
        f"rank {self.rank} is not a member of any group in group_ranks")

    if torch.cuda.is_available():
        self.device = torch.device(f"cuda:{local_rank}")
    else:
        self.device = torch.device("cpu")

    self.use_pynccl = use_pynccl
    self.use_custom_allreduce = use_custom_allreduce

    # lazy import to avoid documentation build error
    from vllm.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce)
    from vllm.distributed.device_communicators.pynccl import (
        PyNcclCommunicator)

    # communicators are only created for world size > 1
    self.pynccl_comm: Optional[PyNcclCommunicator]
    if use_pynccl and self.world_size > 1:
        self.pynccl_comm = PyNcclCommunicator(
            group=self.cpu_group,
            device=self.device,
        )
    else:
        self.pynccl_comm = None

    self.ca_comm: Optional[CustomAllreduce]
    if use_custom_allreduce and self.world_size > 1:
        # Initialize a custom fast all-reduce implementation.
        self.ca_comm = CustomAllreduce(
            group=self.cpu_group,
            device=self.device,
        )
    else:
        self.ca_comm = None

    from vllm.distributed.device_communicators.shm_broadcast import (
        MessageQueue)
    self.mq_broadcaster: Optional[MessageQueue] = None
    if use_message_queue_broadcaster and self.world_size > 1:
        # NOTE(review): `1 << 22` and `6` are presumably the chunk size in
        # bytes and the number of chunks -- confirm against MessageQueue.
        self.mq_broadcaster = MessageQueue.create_from_process_group(
            self.cpu_group, 1 << 22, 6)
@property
def first_rank(self):
    """Global rank stored at index 0 of this group's `ranks` list."""
    group_ranks = self.ranks
    return group_ranks[0]
@property
def last_rank(self):
    """Global rank stored at the final index of this group's `ranks` list."""
    group_ranks = self.ranks
    return group_ranks[-1]
@property
def is_first_rank(self):
    """True when the caller's global rank equals `first_rank`."""
    head = self.first_rank
    return self.rank == head
@property
def is_last_rank(self):
    """True when the caller's global rank equals `last_rank`."""
    tail = self.last_rank
    return self.rank == tail
@property
def next_rank(self):
    """Global rank of the group member after the caller (wraps around)."""
    successor_index = (self.rank_in_group + 1) % self.world_size
    return self.ranks[successor_index]
@property
def prev_rank(self):
    """Global rank of the group member before the caller (wraps around)."""
    predecessor_index = (self.rank_in_group - 1) % self.world_size
    return self.ranks[predecessor_index]
@contextmanager
def graph_capture(
        self, graph_capture_context: Optional[GraphCaptureContext] = None):
    """Context manager wrapping CUDA graph capture for this group.

    Reuses the stream of `graph_capture_context` when one is given;
    otherwise creates a new CUDA stream and a `GraphCaptureContext`
    around it. Yields that context.
    """
    if graph_capture_context is not None:
        stream = graph_capture_context.stream
    else:
        stream = torch.cuda.Stream()
        graph_capture_context = GraphCaptureContext(stream)

    custom_ar = self.ca_comm
    ca_scope = nullcontext() if custom_ar is None else custom_ar.capture()

    with torch.cuda.stream(stream), ca_scope:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the
        #  tensor size is too large, it will fallback to the next
        #  available option.
        # In summary: When using CUDA graph, we use
        #  either custom all-reduce kernel or pynccl. When not using
        #  CUDA graph, we use either custom all-reduce kernel or
        #  PyTorch NCCL. We always prioritize using custom all-reduce
        #  kernel but fall back to PyTorch or pynccl if it is
        #  disabled or not supported.
        nccl = self.pynccl_comm
        pynccl_scope: Any
        if nccl:
            # Switch PyNccl on for the duration of the capture, bound to
            # the stream made current by the enclosing `with`.
            pynccl_scope = nccl.change_state(
                enable=True, stream=torch.cuda.current_stream())
        else:
            pynccl_scope = nullcontext()
        with pynccl_scope:
            yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    """All-reduce `input_` across the group.

    NOTE: The reduction may happen in-place or out-of-place. Treat the
    argument as modified and always use the returned tensor.

    The custom all-reduce communicator is tried first; when it is absent
    or returns `None`, PyNccl is used if it exists and is not disabled,
    otherwise `torch.distributed.all_reduce` on the device group.
    """
    custom_ar = self.ca_comm
    # A single-process group has nothing to reduce.
    if self.world_size == 1:
        return input_

    if custom_ar is not None:
        reduced = custom_ar.custom_all_reduce(input_)
        if reduced is not None:
            return reduced

    nccl = self.pynccl_comm
    if nccl is None or nccl.disabled:
        torch.distributed.all_reduce(input_, group=self.device_group)
    else:
        nccl.all_reduce(input_)
    return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """All-gather `input_` from every rank and concatenate along `dim`.

    Returns `input_` unchanged for a single-rank group. Otherwise the
    result has the input's shape with dimension `dim` multiplied by the
    group size.
    """
    n_ranks = self.world_size
    # A single-process group has nothing to gather.
    if n_ranks == 1:
        return input_

    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    # Normalize a negative dim to its non-negative equivalent.
    if dim < 0:
        dim += input_.dim()

    shape = input_.size()
    # Receive buffer with a leading per-rank axis.
    gathered = torch.empty((n_ranks, ) + shape,
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(gathered,
                                             input_,
                                             group=self.device_group)
    # Move the per-rank axis next to `dim`, then fold the two together.
    merged_shape = shape[:dim] + (n_ranks * shape[dim], ) + shape[dim + 1:]
    return gathered.movedim(0, dim).reshape(merged_shape)
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> Optional[torch.Tensor]:
    """Gather `input_` from every rank onto the destination rank.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.

    Returns:
        `input_` itself when the group has a single rank. Otherwise, on
        the destination rank, the per-rank tensors concatenated along
        `dim`; on every other rank, `None`.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    # Validate `dst` the same way the sibling collectives validate their
    # rank arguments, instead of failing later on `self.ranks[dst]`.
    assert dst < world_size, f"Invalid dst rank ({dst})"
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Only the destination rank allocates receive buffers.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None

    torch.distributed.gather(input_,
                             gather_list,
                             dst=self.ranks[dst],
                             group=self.device_group)

    if gather_list is None:
        return None
    return torch.cat(gather_list, dim=dim)
def broadcast(self, input_: torch.Tensor, src: int = 0):
    """Broadcast `input_` from `src` over the device group.

    NOTE: `src` is the local rank of the source rank.

    Returns `input_`; for a single-rank group no communication happens.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Communication is only needed when there is more than one rank.
    if self.world_size != 1:
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
    return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
    """Broadcast a Python object from `src` and return it on every rank.

    NOTE: `src` is the local rank of the source rank.

    When a message-queue broadcaster exists it is used (only `src == 0`
    is allowed there); otherwise the object goes through
    `torch.distributed.broadcast_object_list` on the CPU group.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # A single-process group has nothing to broadcast.
    if self.world_size == 1:
        return obj

    mq = self.mq_broadcaster
    if mq is not None:
        assert src == 0, "Message queue broadcaster only supports src=0"
        return mq.broadcast_object(obj)

    is_sender = self.rank_in_group == src
    # The sender ships its object; receivers provide a one-slot buffer.
    payload = [obj] if is_sender else [None]
    torch.distributed.broadcast_object_list(payload,
                                            src=self.ranks[src],
                                            group=self.cpu_group)
    return obj if is_sender else payload[0]
def broadcast_object_list(self,
                          obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast a list of Python objects from `src`; returns `obj_list`.

    NOTE: `src` is the local rank of the source rank.
    NOTE: the `group` argument is accepted but not used; the broadcast
    always runs on this coordinator's device group.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Communication is only needed when there is more than one rank.
    if self.world_size != 1:
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
    return obj_list
def send_object(self, obj: Any, dst: int) -> None:
    """Send a picklable Python object to the destination rank.

    NOTE: `dst` is the local rank of the destination rank.

    The object is pickled and sent over the CPU group as two messages:
    first its byte length, then the bytes. `recv_object` is the matching
    receiver.
    """
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert dst != self.rank_in_group, (
        "Invalid destination rank. Destination rank is the same "
        "as the current rank.")

    # Serialize object to tensor and get the size as well
    object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

    size_tensor = torch.tensor([object_tensor.numel()],
                               dtype=torch.long,
                               device="cpu")

    dst_global = self.ranks[dst]

    # Send object size
    torch.distributed.send(size_tensor,
                           dst=dst_global,
                           group=self.cpu_group)

    # Send object
    torch.distributed.send(object_tensor,
                           dst=dst_global,
                           group=self.cpu_group)
def recv_object(self, src: int) -> Any:
    """Receive a pickled Python object from the source rank.

    NOTE: `src` is the local rank of the source rank.

    Counterpart of `send_object`: receives the byte length first, then
    the pickled bytes, both over the CPU group, and unpickles them.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    assert src != self.rank_in_group, (
        "Invalid source rank. Source rank is the same as the current rank."
    )

    size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

    src_global = self.ranks[src]

    # Receive object size
    rank_size = torch.distributed.recv(size_tensor,
                                       src=src_global,
                                       group=self.cpu_group)

    # Tensor to receive serialized objects into.
    object_tensor = torch.empty(  # type: ignore[call-overload]
        size_tensor.item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device="cpu")

    rank_object = torch.distributed.recv(object_tensor,
                                         src=src_global,
                                         group=self.cpu_group)

    # Both messages must have come from the same sender.
    assert rank_object == rank_size, (
        "Received object sender rank does not match the size sender rank.")

    # NOTE(security): unpickling executes arbitrary code; this is only
    # acceptable because the bytes come from a peer rank of this job.
    return pickle.loads(object_tensor.numpy().tobytes())
def broadcast_tensor_dict(
    self,
    tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Broadcast a dictionary of tensors and plain values from `src`.

    NOTE: `src` is the local rank of the source rank.
    NOTE: the `group` and `metadata_group` arguments are overwritten with
    this coordinator's device and CPU groups before use.

    The metadata list goes out through `broadcast_object`; every
    non-empty tensor then follows as an asynchronous broadcast (CPU
    tensors on the CPU group, others on the device group). All handles
    are waited on before returning. Returns `tensor_dict` unchanged when
    torch.distributed is not initialized or the group has one rank.
    """
    # Nothing to do without a process group or with a single rank.
    if (not torch.distributed.is_initialized() or self.world_size == 1):
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"

    pending = []
    if self.rank_in_group == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.broadcast_object(metadata_list, src=src)
        for tensor in tensor_list:
            # Empty tensors are never put on the wire.
            if tensor.numel() == 0:
                continue
            # CPU tensors use the CPU group, everything else the device
            # group.
            pending.append(
                torch.distributed.broadcast(
                    tensor,
                    src=self.ranks[src],
                    group=metadata_group if tensor.is_cpu else group,
                    async_op=True))
    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        for key, value in metadata_list:
            if not isinstance(value, TensorMetadata):
                # Plain values travel inside the metadata itself.
                _update_nested_dict(tensor_dict, key, value)
                continue
            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            # Empty tensors are rebuilt locally; nothing was sent for them.
            if tensor.numel() != 0:
                pending.append(
                    torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=metadata_group if tensor.is_cpu else group,
                        async_op=True))
            _update_nested_dict(tensor_dict, key, tensor)

    for handle in pending:
        handle.wait()
    return tensor_dict
def send_tensor_dict(
    self,
    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Send a dictionary of tensors and plain values to `dst`.

    NOTE: `dst` is the local rank of the destination rank; when omitted
    it defaults to the next rank in the group (wrapping around).

    Returns `tensor_dict` unchanged when torch.distributed is not
    initialized or the group has a single rank; otherwise `None`.
    """
    # Nothing to do without a process group or with a single rank.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    device_pg = self.device_group
    cpu_pg = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert isinstance(
        tensor_dict,
        dict), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)
    for tensor in tensor_list:
        # Empty tensors are never put on the wire.
        if tensor.numel() != 0:
            # CPU tensors use the CPU group, everything else the device
            # group.
            torch.distributed.send(
                tensor,
                dst=self.ranks[dst],
                group=cpu_pg if tensor.is_cpu else device_pg)
    return None
def recv_tensor_dict(
    self,
    src: Optional[int] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Receive a dictionary of tensors and plain values from `src`.

    NOTE: `src` is the local rank of the source rank; when omitted it
    defaults to the previous rank in the group (wrapping around).

    Returns `None` when torch.distributed is not initialized or the
    group has a single rank; otherwise the rebuilt dictionary.
    """
    # Nothing to do without a process group or with a single rank.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    device_pg = self.device_group
    cpu_pg = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    recv_metadata_list = self.recv_object(src=src)
    tensor_dict: Dict[str, Any] = {}
    for key, value in recv_metadata_list:
        if not isinstance(value, TensorMetadata):
            # Plain values travel inside the metadata itself.
            _update_nested_dict(tensor_dict, key, value)
            continue
        tensor = torch.empty(value.size,
                             dtype=value.dtype,
                             device=value.device)
        # Empty tensors are rebuilt locally; nothing was sent for them.
        if tensor.numel() != 0:
            # CPU tensors use the CPU group, everything else the device
            # group.
            torch.distributed.recv(
                tensor,
                src=self.ranks[src],
                group=cpu_pg if tensor.is_cpu else device_pg)
        _update_nested_dict(tensor_dict, key, tensor)
    return tensor_dict
def barrier(self):
    """Synchronize all ranks of the group through the CPU group.

    NOTE: don't use `device_group` here! `barrier` in NCCL is
    terrible because it is internally a broadcast operation with
    secretly created GPU tensors. It is easy to mess up the current
    device. Use the CPU group instead.
    """
    cpu_pg = self.cpu_group
    torch.distributed.barrier(group=cpu_pg)
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
    """Send a tensor to the destination rank.

    NOTE: `dst` is the local rank of the destination rank; when omitted
    it defaults to the next rank in the group (wrapping around).

    PyNccl is used when it exists and is not disabled; otherwise
    `torch.distributed.send` on the device group.
    NOTE(review): the previous docstring called this non-blocking;
    whether that holds depends on the backend -- confirm.
    """
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        # PyNccl addresses peers by their rank inside the group.
        pynccl_comm.send(tensor, dst)
    else:
        # torch.distributed addresses peers by global rank.
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: Optional[int] = None) -> torch.Tensor:
    """Receive a tensor of the given size and dtype from the source rank.

    NOTE: `src` is the local rank of the source rank; when omitted it
    defaults to the previous rank in the group (wrapping around).

    The receive buffer is allocated on `self.device`. PyNccl is used
    when it exists and is not disabled; otherwise
    `torch.distributed.recv` on the device group.
    """
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        # PyNccl addresses peers by their rank inside the group.
        pynccl_comm.recv(tensor, src)
    else:
        # torch.distributed addresses peers by global rank.
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor
def destroy(self):
    """Destroy both process groups and drop every communicator reference.

    Each process group that is still set is passed to
    `torch.distributed.destroy_process_group` and then cleared; the
    communicator attributes are reset to `None`.
    """
    for group_attr in ("device_group", "cpu_group"):
        process_group = getattr(self, group_attr)
        if process_group is not None:
            torch.distributed.destroy_process_group(process_group)
            setattr(self, group_attr, None)

    for comm_attr in ("pynccl_comm", "ca_comm", "mq_broadcaster"):
        if getattr(self, comm_attr) is not None:
            setattr(self, comm_attr, None)
# Coordinator for the world group; `None` until it has been initialized.
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator; asserts it is initialized."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world
def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Build a coordinator for a single group made of `ranks`.

    PyNccl and custom all-reduce are both switched off for this group.
    """
    world_coordinator = GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
    )
    return world_coordinator
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
) -> GroupCoordinator:
    """Build a coordinator for model-parallel groups with PyNccl enabled.

    When `use_custom_allreduce` is `None`, the module-level
    `_ENABLE_CUSTOM_ALL_REDUCE` flag decides.
    """
    custom_allreduce_enabled = (_ENABLE_CUSTOM_ALL_REDUCE
                                if use_custom_allreduce is None else
                                use_custom_allreduce)
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=custom_allreduce_enabled,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )
# Coordinator for the tensor model parallel group; `None` until initialized.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel coordinator; asserts it is set."""
    tp_coordinator = _TP
    assert tp_coordinator is not None, (
        "tensor model parallel group is not initialized")
    return tp_coordinator


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group
# Coordinator for the pipeline model parallel group; `None` until initialized.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel coordinator; asserts it is set."""
    pp_coordinator = _PP
    assert pp_coordinator is not None, (
        "pipeline model parallel group is not initialized")
    return pp_coordinator
def get_local_rank():
global _LOCAL_RANK # kept for backward compatibility
return _LOCAL_RANK get_pipeline_model_parallel_group = get_pp_group
@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It yields a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched on background in the default stream.

    The tensor-parallel group is entered first; the context it yields is
    handed to the pipeline-parallel group so both share it.
    """
    with get_tp_group().graph_capture() as context:
        with get_pp_group().graph_capture(context):
            yield context
# Module-level logger obtained from `init_logger`.
logger = init_logger(__name__)

# Default used by `init_model_parallel_group` when it is called with
# `use_custom_allreduce=None`; changed through `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-wide `_ENABLE_CUSTOM_ALL_REDUCE` flag to `enable`."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment( def init_distributed_environment(
...@@ -100,31 +814,23 @@ def init_distributed_environment( ...@@ -100,31 +814,23 @@ def init_distributed_environment(
init_method=distributed_init_method, init_method=distributed_init_method,
world_size=world_size, world_size=world_size,
rank=rank) rank=rank)
global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP # set the local rank
_DEVICE_WORLD_GROUP = torch.distributed.group.WORLD # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size())) ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, _WORLD = init_world_group(ranks, local_rank, backend)
backend="gloo") else:
# set the local rank assert _WORLD.world_size == torch.distributed.get_world_size(), (
# local_rank is not available in torch ProcessGroup, "world group already initialized with a different world size")
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _LOCAL_RANK
_LOCAL_RANK = local_rank
# A small all_reduce for warmup.
data = torch.zeros(1)
if torch.cuda.is_available():
data = data.to(device=f"cuda:{local_rank}")
torch.distributed.all_reduce(data)
if torch.cuda.is_available():
torch.cuda.synchronize()
del data
def initialize_model_parallel( def initialize_model_parallel(
...@@ -157,8 +863,8 @@ def initialize_model_parallel( ...@@ -157,8 +863,8 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size() world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP backend = backend or torch.distributed.get_backend(
backend = backend or torch.distributed.get_backend() get_world_group().device_group)
if (world_size != if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size): tensor_model_parallel_size * pipeline_model_parallel_size):
...@@ -167,63 +873,39 @@ def initialize_model_parallel( ...@@ -167,63 +873,39 @@ def initialize_model_parallel(
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size // num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size) tensor_model_parallel_size)
num_pipeline_model_parallel_groups: int = (world_size // global _TP
pipeline_model_parallel_size) assert _TP is None, ("tensor model parallel group is already initialized")
rank = torch.distributed.get_rank() group_ranks = []
# Build the tensor model-parallel groups.
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
assert _TP_DEVICE_GROUP is None, (
"tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
ranks = list( ranks = list(
range(i * tensor_model_parallel_size, range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)) (i + 1) * tensor_model_parallel_size))
group = torch.distributed.new_group(ranks, backend=backend) group_ranks.append(ranks)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
_TP_DEVICE_GROUP = group
_TP_CPU_GROUP = cpu_group
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
if tensor_model_parallel_size > 1:
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)
# Initialize a custom fast all-reduce implementation. # message queue broadcaster is only used in tensor model parallel group
if _ENABLE_CUSTOM_ALL_REDUCE: _TP = init_model_parallel_group(group_ranks,
from vllm.distributed.device_communicators.custom_all_reduce import ( get_world_group().local_rank,
CustomAllreduce) backend,
_TP_CA_COMMUNICATOR = CustomAllreduce( use_message_queue_broadcaster=True)
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
global _PP_DEVICE_GROUP, _PP_CPU_GROUP num_pipeline_model_parallel_groups: int = (world_size //
global _PP_PYNCCL_COMMUNICATOR pipeline_model_parallel_size)
global _PP_GLOBAL_RANKS global _PP
assert _PP_DEVICE_GROUP is None, ( assert _PP is None, (
"pipeline model parallel group is already initialized") "pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group = torch.distributed.new_group(ranks, backend=backend) group_ranks.append(ranks)
cpu_group = torch.distributed.new_group(ranks, backend="gloo") # pipeline parallel does not need custom allreduce
if rank in ranks: _PP = init_model_parallel_group(group_ranks,
_PP_DEVICE_GROUP = group get_world_group().local_rank,
_PP_CPU_GROUP = cpu_group backend,
_PP_GLOBAL_RANKS = ranks use_custom_allreduce=False)
if pipeline_model_parallel_size > 1:
_PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
group=_PP_CPU_GROUP,
device=_LOCAL_RANK,
)
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
...@@ -235,8 +917,8 @@ def ensure_model_parallel_initialized( ...@@ -235,8 +917,8 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized. values if the model parallel groups are initialized.
""" """
# get the backend of _DEVICE_WORLD_GROUP backend = backend or torch.distributed.get_backend(
backend = backend or torch.distributed.get_backend() get_world_group().device_group)
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend) pipeline_model_parallel_size, backend)
...@@ -247,148 +929,87 @@ def ensure_model_parallel_initialized( ...@@ -247,148 +929,87 @@ def ensure_model_parallel_initialized(
), ("tensor parallel group already initialized, but of unexpected size: " ), ("tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. " f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}") f"{tensor_model_parallel_size=}")
assert (get_pipeline_model_parallel_world_size( pp_world_size = get_pp_group().world_size
) == pipeline_model_parallel_size), ( assert (pp_world_size == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: " "pipeline parallel group already initialized, but of unexpected size: "
f"{get_pipeline_model_parallel_world_size()=} vs. " f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}") f"{pipeline_model_parallel_size=}")
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized.""" """Check if tensor and pipeline parallel groups are initialized."""
return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) return (_TP is not None and _PP is not None)
def get_cpu_world_group():
"""Get the CPU world group."""
assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
return _CPU_WORLD_GROUP
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP_DEVICE_GROUP is not None, (
"tensor model parallel group is not initialized")
return _TP_DEVICE_GROUP
_TP_STATE_PATCHED = False
def get_tensor_model_parallel_cpu_group():
"""Get the tensor model parallel cpu group the caller rank belongs to."""
assert _TP_CPU_GROUP is not None, (
"tensor model parallel cpu group is not initialized")
return _TP_CPU_GROUP
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
def get_pipeline_model_parallel_group(): This method is for draft workers of speculative decoding to run draft model
"""Get the pipeline model parallel group the caller rank belongs to.""" with different tp degree from that of target model workers.
assert _PP_DEVICE_GROUP is not None, (
"pipeline model parallel group is not initialized")
return _PP_DEVICE_GROUP
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
def get_pipeline_model_parallel_cpu_group(): _TP_STATE_PATCHED = True
"""Get the pipeline model parallel cpu group the caller rank belongs to.""" old_tp_group = get_tp_group()
assert _PP_CPU_GROUP is not None, ( global _TP
"pipeline model parallel cpu group is not initialized") _TP = tp_group
return _PP_CPU_GROUP try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tensor_model_parallel_world_size(): def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group.""" """Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size( return get_tp_group().world_size
group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
return torch.distributed.get_world_size(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_rank(): def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group.""" """Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) return get_tp_group().rank_in_group
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
return torch.distributed.get_rank(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
return _PP_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PP_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PP_GLOBAL_RANKS[last_rank_local]
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
if _TP:
_TP.destroy()
_TP = None
def get_pipeline_model_parallel_next_rank(): global _PP
"""Return the global rank that follows the caller in the pipeline""" if _PP:
assert _PP_GLOBAL_RANKS is not None, ( _PP.destroy()
"Pipeline parallel group is not initialized") _PP = None
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank(): def destroy_distributed_environment():
"""Return the global rank that precedes the caller in the pipeline""" global _WORLD
assert _PP_GLOBAL_RANKS is not None, ( if _WORLD:
"Pipeline parallel group is not initialized") _WORLD.destroy()
rank_in_pipeline = get_pipeline_model_parallel_rank() _WORLD = None
world_size = get_pipeline_model_parallel_world_size() if torch.distributed.is_initialized():
return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] torch.distributed.destroy_process_group()
def destroy_model_parallel(): def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
"""Set the groups to none and destroy them."""
global _TP_DEVICE_GROUP
if _TP_DEVICE_GROUP:
torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
_TP_DEVICE_GROUP = None
global _TP_CPU_GROUP
if _TP_CPU_GROUP:
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
_TP_CPU_GROUP = None
global _TP_PYNCCL_COMMUNICATOR
_TP_PYNCCL_COMMUNICATOR = None
global _PP_DEVICE_GROUP
if _PP_DEVICE_GROUP:
torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
_PP_DEVICE_GROUP = None
global _PP_GLOBAL_RANKS
_PP_GLOBAL_RANKS = None
def is_in_the_same_node(pg: ProcessGroup):
""" """
This is a collective operation that checks if all processes in the group This is a collective operation that returns if each rank is in the same node
are in the same node. It tests if all processes are attached to the same as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory). memory system (shared access to shared memory).
""" """
assert torch.distributed.get_backend( assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, ( pg) != torch.distributed.Backend.NCCL, (
"is_in_the_same_node should be tested with a non-NCCL group.") "in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group # local rank inside the group
rank = torch.distributed.get_rank(group=pg) rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg) world_size = torch.distributed.get_world_size(group=pg)
...@@ -404,22 +1025,27 @@ def is_in_the_same_node(pg: ProcessGroup): ...@@ -404,22 +1025,27 @@ def is_in_the_same_node(pg: ProcessGroup):
try: try:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
if rank == 0: if rank == source_rank:
# create a shared memory segment # create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128) shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message shm.buf[:len(magic_message)] = magic_message
torch.distributed.broadcast_object_list([shm.name], torch.distributed.broadcast_object_list([shm.name],
src=ranks[0], src=ranks[source_rank],
group=pg) group=pg)
is_in_the_same_node[0] = 1 is_in_the_same_node[rank] = 1
else: else:
# try to open the shared memory segment # try to open the shared memory segment
recv = [None] recv = [None]
torch.distributed.broadcast_object_list(recv, torch.distributed.broadcast_object_list(recv,
src=ranks[0], src=ranks[source_rank],
group=pg) group=pg)
name = recv[0] name = recv[0]
shm = shared_memory.SharedMemory(name=name) # fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
shm = shared_memory.SharedMemory(name=name)
if shm.buf[:len(magic_message)] == magic_message: if shm.buf[:len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1 is_in_the_same_node[rank] = 1
except Exception as e: except Exception as e:
...@@ -432,14 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup): ...@@ -432,14 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup):
# clean up the shared memory segment # clean up the shared memory segment
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
if rank == 0: if rank == source_rank and shm:
if shm: shm.unlink()
shm.unlink()
else:
if shm:
# fix to https://stackoverflow.com/q/62748654/9191338
resource_tracker.unregister(
shm._name, "shared_memory") # type: ignore[attr-defined]
torch.distributed.all_reduce(is_in_the_same_node, group=pg) torch.distributed.all_reduce(is_in_the_same_node, group=pg)
return is_in_the_same_node.sum().item() == world_size return [x == 1 for x in is_in_the_same_node.tolist()]
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence from typing import Sequence, Tuple
import torch import torch
...@@ -46,3 +46,19 @@ def split_tensor_along_last_dim( ...@@ -46,3 +46,19 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list return tensor_list
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Return the ``(start_layer, end_layer)`` slice owned by one pipeline stage.

    Every stage is assigned ``num_hidden_layers // pp_size`` consecutive
    layers. When the layer count is not a multiple of ``pp_size``, the
    leftover layers are all given to the final stage
    (``pp_rank == pp_size - 1``), whose end index is ``num_hidden_layers``.

    Args:
        num_hidden_layers: Total number of layers to split.
        pp_rank: Index of the pipeline stage being queried.
        pp_size: Number of pipeline stages.

    Returns:
        A ``(start_layer, end_layer)`` tuple; ``end_layer - start_layer`` is
        the number of layers assigned to this stage.
    """
    base_share = num_hidden_layers // pp_size
    first = base_share * pp_rank
    # Only the last stage absorbs the remainder of the integer division.
    owns_remainder = pp_rank + 1 == pp_size
    last = num_hidden_layers if owns_remainder else first + base_share
    return first, last
import argparse import argparse
import dataclasses import dataclasses
import json import json
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig,
TokenizerPoolConfig, VisionLanguageConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple from vllm.utils import FlexibleArgumentParser
def nullable_str(val: str): def nullable_str(val: str):
...@@ -66,6 +66,9 @@ class EngineArgs: ...@@ -66,6 +66,9 @@ class EngineArgs:
enable_lora: bool = False enable_lora: bool = False
max_loras: int = 1 max_loras: int = 1
max_lora_rank: int = 16 max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256 lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None long_lora_scaling_factors: Optional[Tuple[float]] = None
...@@ -78,83 +81,31 @@ class EngineArgs: ...@@ -78,83 +81,31 @@ class EngineArgs:
model_loader_extra_config: Optional[dict] = None model_loader_extra_config: Optional[dict] = None
preemption_mode: Optional[str] = None preemption_mode: Optional[str] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
image_processor: Optional[str] = None
image_processor_revision: Optional[str] = None
disable_image_processor: bool = False
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False enable_chunked_prefill: bool = False
guided_decoding_backend: str = 'outlines' guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration. # Speculative decoding configuration.
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None qlora_adapter_name_or_path: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
self.tokenizer = self.model self.tokenizer = self.model
@staticmethod @staticmethod
def add_cli_args_for_vlm( def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
choices=[
t.name.lower()
for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM.'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--image-processor',
type=str,
default=EngineArgs.image_processor,
help='Name or path of the huggingface image processor to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--image-processor-revision',
type=str,
default=None,
help='Revision of the huggingface image processor version to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--disable-image-processor',
action='store_true',
help='Disables the use of image processor, even if one is defined '
'for the model on huggingface.')
return parser
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# Model arguments # Model arguments
...@@ -230,7 +181,7 @@ class EngineArgs: ...@@ -230,7 +181,7 @@ class EngineArgs:
'* "dummy" will initialize the weights with random values, ' '* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n' 'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from ' '* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples' 'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n' 'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes ' '* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n') 'quantization.\n')
...@@ -501,15 +452,26 @@ class EngineArgs: ...@@ -501,15 +452,26 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. ' 'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or ' 'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.')) 'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"], choices=[
"auto", "cuda", "neuron", "cpu", "openvino",
"tpu", "xpu"
],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)
parser.add_argument( parser.add_argument(
'--scheduler-delay-factor', '--scheduler-delay-factor',
type=float, type=float,
...@@ -534,6 +496,13 @@ class EngineArgs: ...@@ -534,6 +496,13 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens, default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.') 'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument( parser.add_argument(
'--speculative-max-model-len', '--speculative-max-model-len',
...@@ -564,6 +533,38 @@ class EngineArgs: ...@@ -564,6 +533,38 @@ class EngineArgs:
help='Min size of window for ngram prompt lookup in speculative ' help='Min size of window for ngram prompt lookup in speculative '
'decoding.') 'decoding.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
help='Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding. Defaults to 0.09')
parser.add_argument(
'--typical-acceptance-sampler-posterior-alpha',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
help='A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument('--model-loader-extra-config', parser.add_argument('--model-loader-extra-config',
type=nullable_str, type=nullable_str,
default=EngineArgs.model_loader_extra_config, default=EngineArgs.model_loader_extra_config,
...@@ -573,7 +574,7 @@ class EngineArgs: ...@@ -573,7 +574,7 @@ class EngineArgs:
'This should be a JSON string that will be ' 'This should be a JSON string that will be '
'parsed into a dictionary.') 'parsed into a dictionary.')
parser.add_argument( parser.add_argument(
'--preemption_mode', '--preemption-mode',
type=str, type=str,
default=None, default=None,
help='If \'recompute\', the engine performs preemption by block ' help='If \'recompute\', the engine performs preemption by block '
...@@ -598,6 +599,13 @@ class EngineArgs: ...@@ -598,6 +599,13 @@ class EngineArgs:
type=str, type=str,
default=None, default=None,
help='Name or path of the QLoRA adapter.') help='Name or path of the QLoRA adapter.')
parser.add_argument(
'--otlp-traces-endpoint',
type=str,
default=None,
help='Target URL to which OpenTelemetry traces will be sent.')
return parser return parser
@classmethod @classmethod
...@@ -625,6 +633,7 @@ class EngineArgs: ...@@ -625,6 +633,7 @@ class EngineArgs:
raise ValueError( raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support " "BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}") f"'bitsandbytes' quantization, but got {self.quantization}")
multimodal_config = MultiModalConfig()
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = ModelConfig( model_config = ModelConfig(
...@@ -648,7 +657,8 @@ class EngineArgs: ...@@ -648,7 +657,8 @@ class EngineArgs:
max_logprobs=self.max_logprobs, max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name) served_model_name=self.served_model_name,
multimodal_config=multimodal_config)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
...@@ -676,6 +686,8 @@ class EngineArgs: ...@@ -676,6 +686,8 @@ class EngineArgs:
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
target_dtype=self.dtype, target_dtype=self.dtype,
speculative_model=self.speculative_model, speculative_model=self.speculative_model,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens, num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self. speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size, speculative_disable_by_batch_size,
...@@ -684,6 +696,12 @@ class EngineArgs: ...@@ -684,6 +696,12 @@ class EngineArgs:
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
...@@ -722,40 +740,17 @@ class EngineArgs: ...@@ -722,40 +740,17 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config, model_loader_extra_config=self.model_loader_extra_config,
) )
if self.image_input_type: prompt_adapter_config = PromptAdapterConfig(
if (not self.image_token_id or not self.image_input_shape max_prompt_adapters=self.max_prompt_adapters,
or not self.image_feature_size): max_prompt_adapter_token=self.max_prompt_adapter_token) \
raise ValueError( if self.enable_prompt_adapter else None
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)
self.image_processor = None
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None
decoding_config = DecodingConfig( decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend) guided_decoding_backend=self.guided_decoding_backend)
observability_config = ObservabilityConfig(
otlp_traces_endpoint=self.otlp_traces_endpoint)
if (model_config.get_sliding_window() is not None if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled and scheduler_config.chunked_prefill_enabled
and not scheduler_config.use_v2_block_manager): and not scheduler_config.use_v2_block_manager):
...@@ -763,16 +758,20 @@ class EngineArgs: ...@@ -763,16 +758,20 @@ class EngineArgs:
"Chunked prefill is not supported with sliding window. " "Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.") "Set --disable-sliding-window to disable sliding window.")
return EngineConfig(model_config=model_config, return EngineConfig(
cache_config=cache_config, model_config=model_config,
parallel_config=parallel_config, cache_config=cache_config,
scheduler_config=scheduler_config, parallel_config=parallel_config,
device_config=device_config, scheduler_config=scheduler_config,
lora_config=lora_config, device_config=device_config,
vision_language_config=vision_language_config, lora_config=lora_config,
speculative_config=speculative_config, multimodal_config=multimodal_config,
load_config=load_config, speculative_config=speculative_config,
decoding_config=decoding_config) load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
)
@dataclass @dataclass
...@@ -783,8 +782,8 @@ class AsyncEngineArgs(EngineArgs): ...@@ -783,8 +782,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser, def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser: async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only: if not async_args_only:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', parser.add_argument('--engine-use-ray',
...@@ -805,13 +804,9 @@ class AsyncEngineArgs(EngineArgs): ...@@ -805,13 +804,9 @@ class AsyncEngineArgs(EngineArgs):
# These functions are used by sphinx to build the documentation # These functions are used by sphinx to build the documentation
def _engine_args_parser(): def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser()) return EngineArgs.add_cli_args(FlexibleArgumentParser())
def _async_engine_args_parser(): def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True) async_args_only=True)
def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs from vllm.inputs import LLMInputs, PromptInputs
...@@ -17,6 +18,7 @@ from vllm.logger import init_logger ...@@ -17,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -210,7 +212,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -210,7 +212,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
async def step_async( async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
...@@ -220,18 +223,22 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -220,18 +223,22 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) finished_requests_ids=finished_requests_ids)
output = await self.model_executor.execute_model_async( output = await self.model_executor.execute_model_async(
execute_model_req) execute_model_req)
else: else:
...@@ -244,21 +251,21 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -244,21 +251,21 @@ class _AsyncLLMEngine(LLMEngine):
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs: # Tracing
# Stop the execute model loop in parallel workers until there are self.do_tracing(scheduler_outputs)
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs return request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async( async def process_model_inputs_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = {"prompt": inputs} inputs = {"prompt": inputs}
...@@ -274,17 +281,27 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -274,17 +281,27 @@ class _AsyncLLMEngine(LLMEngine):
else: else:
prompt_token_ids = inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids, if prompt_adapter_request:
prompt=inputs.get("prompt"), prompt_token_ids = [
multi_modal_data=inputs.get("multi_modal_data")) 0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
...@@ -293,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -293,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time() arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async( processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request) request_id=request_id,
inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -301,9 +321,13 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -301,9 +321,13 @@ class _AsyncLLMEngine(LLMEngine):
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
) )
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
...@@ -369,17 +393,42 @@ class AsyncLLMEngine: ...@@ -369,17 +393,42 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, ( assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.") "Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray": elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config) initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
...@@ -461,7 +510,8 @@ class AsyncLLMEngine: ...@@ -461,7 +510,8 @@ class AsyncLLMEngine:
# order of the arguments. # order of the arguments.
cache_config = kwargs["cache_config"] cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"] parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1: if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization num_gpus = cache_config.gpu_memory_utilization
else: else:
num_gpus = 1 num_gpus = 1
...@@ -469,7 +519,7 @@ class AsyncLLMEngine: ...@@ -469,7 +519,7 @@ class AsyncLLMEngine:
self._engine_class).remote self._engine_class).remote
return engine_class(*args, **kwargs) return engine_class(*args, **kwargs)
async def engine_step(self) -> bool: async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
Returns True if there are in-progress requests.""" Returns True if there are in-progress requests."""
...@@ -500,14 +550,16 @@ class AsyncLLMEngine: ...@@ -500,14 +550,16 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore request_outputs = await self.engine.step.remote() # type: ignore
else: else:
request_outputs = await self.engine.step_async() request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
finished = True
for request_output in request_outputs: for request_output in request_outputs:
self._request_tracker.process_request_output( self._request_tracker.process_request_output(
request_output, verbose=self.log_requests) request_output, verbose=self.log_requests)
finished = finished and request_output.finished
return len(request_outputs) > 0 return not finished
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray: if self.engine_use_ray:
...@@ -516,18 +568,65 @@ class AsyncLLMEngine: ...@@ -516,18 +568,65 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
has_requests_in_progress = False if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True: while True:
if not has_requests_in_progress: if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...") logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!") logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors # Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts). # (eg. NCCL timeouts).
try: try:
has_requests_in_progress = await asyncio.wait_for( async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S) done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc: except asyncio.TimeoutError as exc:
logger.error( logger.error(
"Engine iteration timed out. This should never happen!") "Engine iteration timed out. This should never happen!")
...@@ -542,6 +641,8 @@ class AsyncLLMEngine: ...@@ -542,6 +641,8 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
if isinstance(inputs, str): if isinstance(inputs, str):
...@@ -577,25 +678,14 @@ class AsyncLLMEngine: ...@@ -577,25 +678,14 @@ class AsyncLLMEngine:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
if self.engine_use_ray:
processed_inputs = await self.engine.process_model_inputs_async \
.remote( # type: ignore
request_id=request_id,
inputs=inputs,
lora_request=lora_request)
else:
processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id,
inputs=inputs,
lora_request=lora_request)
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
inputs=processed_inputs, inputs=inputs,
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
) trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
return stream return stream
...@@ -605,6 +695,8 @@ class AsyncLLMEngine: ...@@ -605,6 +695,8 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -619,6 +711,9 @@ class AsyncLLMEngine: ...@@ -619,6 +711,9 @@ class AsyncLLMEngine:
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine The output `RequestOutput` objects from the LLMEngine
...@@ -672,6 +767,8 @@ class AsyncLLMEngine: ...@@ -672,6 +767,8 @@ class AsyncLLMEngine:
inputs, inputs,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
): ):
yield LLMEngine.validate_output(output, RequestOutput) yield LLMEngine.validate_output(output, RequestOutput)
...@@ -681,6 +778,7 @@ class AsyncLLMEngine: ...@@ -681,6 +778,7 @@ class AsyncLLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]: ) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -695,6 +793,7 @@ class AsyncLLMEngine: ...@@ -695,6 +793,7 @@ class AsyncLLMEngine:
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
Yields: Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
...@@ -746,6 +845,7 @@ class AsyncLLMEngine: ...@@ -746,6 +845,7 @@ class AsyncLLMEngine:
inputs, inputs,
pooling_params, pooling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
): ):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput) yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
...@@ -756,6 +856,8 @@ class AsyncLLMEngine: ...@@ -756,6 +856,8 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
*, *,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or """Common logic to process requests with SamplingParams or
PoolingParams.""" PoolingParams."""
...@@ -767,6 +869,8 @@ class AsyncLLMEngine: ...@@ -767,6 +869,8 @@ class AsyncLLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
) )
try: try:
...@@ -846,3 +950,10 @@ class AsyncLLMEngine: ...@@ -846,3 +950,10 @@ class AsyncLLMEngine:
else: else:
await self.engine.check_health_async() await self.engine.check_health_async()
logger.debug("Health check took %fs", time.perf_counter() - t) logger.debug("Health check took %fs", time.perf_counter() - t)
async def is_tracing_enabled(self) -> bool:
    """Report whether the wrapped engine has tracing enabled.

    When ``self.engine_use_ray`` is falsy the engine's own
    ``is_tracing_enabled()`` is called directly; otherwise the same query
    is issued through ``.remote()`` and awaited.

    Returns:
        Whatever the underlying engine reports for ``is_tracing_enabled``.
    """
    # Fast path: the engine lives in this process, no await needed.
    if not self.engine_use_ray:
        return self.engine.is_tracing_enabled()

    # The engine is reached through a remote handle; await its answer.
    remote_call = self.engine.is_tracing_enabled.remote  # type: ignore
    return await remote_call()
# Workaround for https://github.com/python/cpython/issues/86296
#
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
# Licensed under the Apache License (Apache-2.0)
import asyncio
import enum
import sys
import warnings
from types import TracebackType
from typing import Any, Optional, Type
if sys.version_info[:2] >= (3, 11):
    # The standard library provides the context manager from 3.11 onwards.
    from asyncio import timeout as asyncio_timeout
else:

    def asyncio_timeout(delay: Optional[float]) -> "Timeout":
        """Build a timeout context manager relative to the current loop time.

        Handy when timeout logic has to wrap a block of code, or when
        ``asyncio.wait_for`` is not a good fit. Example:

        >>> async with timeout(0.001):
        ...     async with aiohttp.get('https://github.com') as r:
        ...         await r.text()

        Args:
            delay: Number of seconds until expiry, or ``None`` to create a
                manager without a deadline.

        Returns:
            A ``Timeout`` bound to the running event loop.
        """
        loop = asyncio.get_running_loop()
        if delay is None:
            deadline = None
        else:
            deadline = loop.time() + delay
        return Timeout(deadline, loop)

    class _State(enum.Enum):
        """Lifecycle phases a ``Timeout`` instance moves through."""

        INIT = "INIT"
        ENTER = "ENTER"
        TIMEOUT = "TIMEOUT"
        EXIT = "EXIT"

    class Timeout:
        """Context manager that cancels the current task at a deadline.

        Internal class: obtain instances through the factory function
        rather than instantiating it directly.

        Implementation notes:

        * ``async with`` is the preferred form. Nothing here actually needs
          to be asynchronous, but ``async with`` makes it explicit that the
          manager belongs inside a coroutine, which prevents many misuses.
        * If the deadline has already passed when the timeout is scheduled,
          the callback is queued with ``call_soon`` so the timeout fires as
          early as possible.
        """

        __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")

        def __init__(self, deadline: Optional[float],
                     loop: asyncio.AbstractEventLoop) -> None:
            """Store the loop and the optional absolute deadline.

            Args:
                deadline: Absolute time on the ``loop.time()`` clock, or
                    ``None`` for no deadline.
                loop: Event loop used to schedule the timeout callback.
            """
            self._loop = loop
            self._state = _State.INIT
            self._timeout_handler: Optional[asyncio.Handle] = None

            if deadline is not None:
                # update() validates the state and records the deadline.
                self.update(deadline)
            else:
                self._deadline: Optional[float] = None

        def __enter__(self) -> "Timeout":
            """Synchronous entry; emits a ``DeprecationWarning``."""
            warnings.warn(
                "with timeout() is deprecated, use async with timeout()",
                DeprecationWarning,
                stacklevel=2,
            )
            self._do_enter()
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            """Synchronous exit; delegates to ``_do_exit``."""
            self._do_exit(exc_type)
            return None

        async def __aenter__(self) -> "Timeout":
            """Asynchronous entry; arms the timeout and returns ``self``."""
            self._do_enter()
            return self

        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Optional[bool]:
            """Asynchronous exit; delegates to ``_do_exit``."""
            self._do_exit(exc_type)
            return None

        @property
        def expired(self) -> bool:
            """Whether the timeout fired while the block was executing."""
            return self._state is _State.TIMEOUT

        @property
        def deadline(self) -> Optional[float]:
            """The currently configured absolute deadline, if any."""
            return self._deadline

        def reject(self) -> None:
            """Cancel the scheduled timeout, if there is one.

            ("cancel" might be a nicer name, but ``task.cancel()`` already
            means raising ``CancelledError`` in asyncio.)

            Raises:
                RuntimeError: If the manager has already timed out or
                    exited.
            """
            still_pending = self._state in (_State.INIT, _State.ENTER)
            if not still_pending:
                raise RuntimeError(f"invalid state {self._state.value}")
            self._reject()

        def _reject(self) -> None:
            """Cancel and forget the pending timeout handle, if present."""
            handler = self._timeout_handler
            if handler is not None:
                handler.cancel()
                self._timeout_handler = None

        def shift(self, delay: float) -> None:
            """Move the deadline by ``delay`` seconds (may be negative).

            Raises:
                RuntimeError: If no deadline is currently set.
            """
            current = self._deadline
            if current is None:
                raise RuntimeError(
                    "cannot shift timeout if deadline is not scheduled")
            self.update(current + delay)

        def update(self, deadline: float) -> None:
            """Replace the deadline with a new absolute value.

            ``deadline`` is expressed on the same clock as ``loop.time()``,
            which is not POSIX time: its origin is undefined (for example,
            the moment the system was powered on). A deadline that already
            lies in the past makes the timeout fire immediately.

            Raises:
                RuntimeError: If the manager has already exited or expired.
            """
            state = self._state
            if state is _State.EXIT:
                raise RuntimeError(
                    "cannot reschedule after exit from context manager")
            if state is _State.TIMEOUT:
                raise RuntimeError("cannot reschedule expired timeout")

            handler = self._timeout_handler
            if handler is not None:
                handler.cancel()
            self._deadline = deadline

            # Before entry nothing is scheduled; _do_enter() arms it later.
            if state is not _State.INIT:
                self._reschedule()

        def _reschedule(self) -> None:
            """Schedule ``_on_timeout`` for the current deadline."""
            assert self._state == _State.ENTER
            deadline = self._deadline
            if deadline is None:
                return

            now = self._loop.time()
            if self._timeout_handler is not None:
                self._timeout_handler.cancel()

            task = asyncio.current_task()
            if deadline <= now:
                # Already overdue: fire on the next loop iteration.
                handle = self._loop.call_soon(self._on_timeout, task)
            else:
                handle = self._loop.call_at(deadline, self._on_timeout, task)
            self._timeout_handler = handle

        def _do_enter(self) -> None:
            """Switch from INIT to ENTER and arm the timeout.

            Raises:
                RuntimeError: If the manager was already entered.
            """
            if self._state != _State.INIT:
                raise RuntimeError(f"invalid state {self._state.value}")
            self._state = _State.ENTER
            self._reschedule()

        def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
            """Finish the context, translating our own cancellation.

            Raises:
                asyncio.TimeoutError: If the block is exiting with
                    ``CancelledError`` after this timeout fired.
            """
            timed_out = self._state is _State.TIMEOUT
            if timed_out and exc_type is asyncio.CancelledError:
                self._timeout_handler = None
                raise asyncio.TimeoutError

            # The timeout has not expired: mark exit and disarm the handle.
            self._state = _State.EXIT
            self._reject()
            return None

        def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
            """Loop callback: cancel ``task`` and record the expiry."""
            if task:
                task.cancel()
            self._state = _State.TIMEOUT
            # Drop the reference to the handle early.
            self._timeout_handler = None
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ParallelConfig, LoRAConfig, ModelConfig, MultiModalConfig,
SchedulerConfig, SpeculativeConfig, ObservabilityConfig, ParallelConfig,
VisionLanguageConfig) PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs) SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence, PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter from vllm.utils import Counter
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig): def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
try: config = try_get_generation_config(
return GenerationConfig.from_pretrained( model_config.model,
model_config.model, trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision, revision=model_config.revision,
).to_diff_dict() )
except OSError:
# Not found. if config is None:
return {} return {}
return config.to_diff_dict()
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
...@@ -81,12 +89,14 @@ class LLMEngine: ...@@ -81,12 +89,14 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device. device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA. lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision multimodal_config (Optional): The configuration related to multimodal
language models. models.
speculative_config (Optional): The configuration related to speculative speculative_config (Optional): The configuration related to speculative
decoding. decoding.
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection. usage_context: Specified entry point, used for usage info collection.
""" """
...@@ -151,12 +161,15 @@ class LLMEngine: ...@@ -151,12 +161,15 @@ class LLMEngine:
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig], decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
...@@ -165,11 +178,14 @@ class LLMEngine: ...@@ -165,11 +178,14 @@ class LLMEngine:
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, " "enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)", "decoding_config=%r, observability_config=%r, "
vllm.__version__, "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)",
VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
model_config.tokenizer, model_config.tokenizer,
...@@ -185,6 +201,7 @@ class LLMEngine: ...@@ -185,6 +201,7 @@ class LLMEngine:
load_config.download_dir, load_config.download_dir,
load_config.load_format, load_config.load_format,
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce, parallel_config.disable_custom_all_reduce,
model_config.quantization, model_config.quantization,
model_config.enforce_eager, model_config.enforce_eager,
...@@ -192,21 +209,27 @@ class LLMEngine: ...@@ -192,21 +209,27 @@ class LLMEngine:
model_config.quantization_param_path, model_config.quantization_param_path,
device_config.device, device_config.device,
decoding_config, decoding_config,
observability_config,
model_config.seed, model_config.seed,
model_config.served_model_name, model_config.served_model_name,
scheduler_config.use_v2_block_manager,
cache_config.enable_prefix_caching,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.vision_language_config = vision_language_config self.multimodal_config = multimodal_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.load_config = load_config self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig() self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
...@@ -220,6 +243,9 @@ class LLMEngine: ...@@ -220,6 +243,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
model_config) model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
self.model_executor = executor_class( self.model_executor = executor_class(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
...@@ -227,9 +253,10 @@ class LLMEngine: ...@@ -227,9 +253,10 @@ class LLMEngine:
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, multimodal_config=multimodal_config,
speculative_config=speculative_config, speculative_config=speculative_config,
load_config=load_config, load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
) )
if not self.model_config.embedding_mode: if not self.model_config.embedding_mode:
...@@ -257,11 +284,13 @@ class LLMEngine: ...@@ -257,11 +284,13 @@ class LLMEngine:
"quantization": "quantization":
model_config.quantization, model_config.quantization,
"kv_cache_dtype": "kv_cache_dtype":
cache_config.cache_dtype, str(cache_config.cache_dtype),
# Feature flags # Feature flags
"enable_lora": "enable_lora":
bool(lora_config), bool(lora_config),
"enable_prompt_adapter":
bool(prompt_adapter_config),
"enable_prefix_caching": "enable_prefix_caching":
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
"enforce_eager": "enforce_eager":
...@@ -278,15 +307,35 @@ class LLMEngine: ...@@ -278,15 +307,35 @@ class LLMEngine:
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging. # Metric Logging.
if self.log_stats: if self.log_stats:
self.stat_logger = StatLogger( if stat_loggers is not None:
local_interval=_LOCAL_LOGGING_INTERVAL_SEC, self.stat_loggers = stat_loggers
labels=dict(model_name=model_config.served_model_name), else:
max_model_len=self.model_config.max_model_len) self.stat_loggers = {
self.stat_logger.info("cache_config", self.cache_config) "logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or # Create sequence output processor, e.g. for beam search or
# speculative decoding. # speculative decoding.
...@@ -336,14 +385,27 @@ class LLMEngine: ...@@ -336,14 +385,27 @@ class LLMEngine:
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor executor_class = CPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
elif distributed_executor_backend == "ray": elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config) initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor
...@@ -355,7 +417,6 @@ class LLMEngine: ...@@ -355,7 +417,6 @@ class LLMEngine:
else: else:
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor executor_class = GPUExecutor
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), **engine_config.to_dict(),
...@@ -416,6 +477,9 @@ class LLMEngine: ...@@ -416,6 +477,9 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_eos_token_id( def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]: self, lora_request: Optional[LoRARequest]) -> Optional[int]:
...@@ -433,6 +497,8 @@ class LLMEngine: ...@@ -433,6 +497,8 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
) -> None: ) -> None:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
...@@ -440,7 +506,7 @@ class LLMEngine: ...@@ -440,7 +506,7 @@ class LLMEngine:
eos_token_id = self._get_eos_token_id(lora_request) eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request) lora_request, prompt_adapter_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams # Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams): if isinstance(params, SamplingParams):
...@@ -450,7 +516,8 @@ class LLMEngine: ...@@ -450,7 +516,8 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
) trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling( seq_group = self._create_sequence_group_with_pooling(
request_id, request_id,
...@@ -458,19 +525,28 @@ class LLMEngine: ...@@ -458,19 +525,28 @@ class LLMEngine:
params, params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
) prompt_adapter_request=prompt_adapter_request)
else: else:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler with least unfinished seqs.
self.scheduler.add_seq_group(seq_group) costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
def stop_remote_worker_execution_loop(self) -> None:
    """Ask the model executor to stop its remote worker execution loop."""
    executor = self.model_executor
    executor.stop_remote_worker_execution_loop()
def process_model_inputs( def process_model_inputs(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> LLMInputs:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = {"prompt": inputs} inputs = {"prompt": inputs}
...@@ -485,9 +561,16 @@ class LLMEngine: ...@@ -485,9 +561,16 @@ class LLMEngine:
else: else:
prompt_token_ids = inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids, if prompt_adapter_request:
prompt=inputs.get("prompt"), prompt_token_ids = \
multi_modal_data=inputs.get("multi_modal_data")) [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
return self.input_processor(llm_inputs)
def add_request( def add_request(
self, self,
...@@ -496,6 +579,8 @@ class LLMEngine: ...@@ -496,6 +579,8 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -513,6 +598,7 @@ class LLMEngine: ...@@ -513,6 +598,7 @@ class LLMEngine:
:class:`~vllm.PoolingParams` for pooling. :class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
trace_headers: OpenTelemetry trace headers.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -544,9 +630,11 @@ class LLMEngine: ...@@ -544,9 +630,11 @@ class LLMEngine:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = self.process_model_inputs(request_id=request_id, processed_inputs = self.process_model_inputs(
inputs=inputs, request_id=request_id,
lora_request=lora_request) inputs=inputs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -554,6 +642,8 @@ class LLMEngine: ...@@ -554,6 +642,8 @@ class LLMEngine:
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
) )
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
...@@ -563,6 +653,8 @@ class LLMEngine: ...@@ -563,6 +653,8 @@ class LLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
...@@ -576,19 +668,19 @@ class LLMEngine: ...@@ -576,19 +668,19 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields) self.generation_config_fields, seq.eos_token_id)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(
seqs=[seq], request_id=request_id,
arrival_time=arrival_time, seqs=[seq],
sampling_params=sampling_params, arrival_time=arrival_time,
lora_request=lora_request) sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
return seq_group return seq_group
...@@ -599,16 +691,19 @@ class LLMEngine: ...@@ -599,16 +691,19 @@ class LLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone() pooling_params = pooling_params.clone()
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(
seqs=[seq], request_id=request_id,
arrival_time=arrival_time, seqs=[seq],
lora_request=lora_request, arrival_time=arrival_time,
pooling_params=pooling_params) lora_request=lora_request,
pooling_params=pooling_params,
prompt_adapter_request=prompt_adapter_request)
return seq_group return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
...@@ -628,7 +723,8 @@ class LLMEngine: ...@@ -628,7 +723,8 @@ class LLMEngine:
>>> # abort the request >>> # abort the request
>>> engine.abort_request(request_id) >>> engine.abort_request(request_id)
""" """
self.scheduler.abort_seq_group(request_id) for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
...@@ -640,11 +736,20 @@ class LLMEngine: ...@@ -640,11 +736,20 @@ class LLMEngine:
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)
def has_unfinished_requests_for_virtual_engine(
        self, virtual_engine: int) -> bool:
    """Report whether one virtual engine still has unfinished requests.

    Args:
        virtual_engine: Index into ``self.scheduler`` selecting the
            scheduler to query.

    Returns:
        The result of ``has_unfinished_seqs()`` on the selected scheduler.
    """
    selected_scheduler = self.scheduler[virtual_engine]
    return selected_scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs( def _process_sequence_group_outputs(
self, self,
...@@ -693,7 +798,8 @@ class LLMEngine: ...@@ -693,7 +798,8 @@ class LLMEngine:
self.output_processor.process_outputs(seq_group, outputs) self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[Union[RequestOutput, request_outputs: List[Union[RequestOutput,
...@@ -759,9 +865,16 @@ class LLMEngine: ...@@ -759,9 +865,16 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs): >>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break >>> break
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
...@@ -769,7 +882,7 @@ class LLMEngine: ...@@ -769,7 +882,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) finished_requests_ids=finished_requests_ids)
output = self.model_executor.execute_model( output = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
else: else:
...@@ -782,7 +895,10 @@ class LLMEngine: ...@@ -782,7 +895,10 @@ class LLMEngine:
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs: # Tracing
self.do_tracing(scheduler_outputs)
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there are # Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in # more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks # torch.distributed ops which may otherwise timeout, and unblocks
...@@ -792,14 +908,24 @@ class LLMEngine: ...@@ -792,14 +908,24 @@ class LLMEngine:
return request_outputs return request_outputs
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
    """Register a stat logger under a name that is not yet in use.

    Args:
        logger_name: Key under which the logger is stored in
            ``self.stat_loggers``.
        logger: The stat logger to register.

    Raises:
        KeyError: If ``logger_name`` is already registered.
    """
    registry = self.stat_loggers
    if logger_name not in registry:
        registry[logger_name] = logger
        return
    raise KeyError(f"Logger with name {logger_name} already exists.")
def remove_logger(self, logger_name: str) -> None:
    """Drop a previously registered stat logger.

    Args:
        logger_name: Key of the logger to delete from
            ``self.stat_loggers``.

    Raises:
        KeyError: If no logger is registered under ``logger_name``.
    """
    registry = self.stat_loggers
    if logger_name in registry:
        del registry[logger_name]
        return
    raise KeyError(f"Logger with name {logger_name} does not exist.")
def do_log_stats( def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None: model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active.""" """Forced log when no requests active."""
if self.log_stats: if self.log_stats:
self.stat_logger.log( for logger in self.stat_loggers.values():
self._get_stats(scheduler_outputs, model_output)) logger.log(self._get_stats(scheduler_outputs, model_output))
def _get_stats( def _get_stats(
self, self,
...@@ -817,23 +943,28 @@ class LLMEngine: ...@@ -817,23 +943,28 @@ class LLMEngine:
# System State # System State
# Scheduler State # Scheduler State
num_running_sys = len(self.scheduler.running) num_running_sys = sum(
num_swapped_sys = len(self.scheduler.swapped) len(scheduler.running) for scheduler in self.scheduler)
num_waiting_sys = len(self.scheduler.waiting) num_swapped_sys = sum(
len(scheduler.swapped) for scheduler in self.scheduler)
num_waiting_sys = sum(
len(scheduler.waiting) for scheduler in self.scheduler)
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
gpu_cache_usage_sys = 0. gpu_cache_usage_sys = 0.
if num_total_gpu is not None: if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( num_free_gpu = sum(
) scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu is not None and num_total_cpu > 0: if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = sum(
) scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Iteration stats # Iteration stats
...@@ -970,8 +1101,82 @@ class LLMEngine: ...@@ -970,8 +1101,82 @@ class LLMEngine:
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id) return self.model_executor.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.model_executor.list_loras() return self.model_executor.list_loras()
def pin_lora(self, lora_id: int) -> bool:
    """Forward a pin request for ``lora_id`` to the model executor.

    Returns:
        Whatever the executor's ``pin_lora`` returns.
    """
    executor = self.model_executor
    return executor.pin_lora(lora_id)
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Forward a prompt adapter request to the model executor.

    Returns:
        Whatever the executor's ``add_prompt_adapter`` returns.
    """
    executor = self.model_executor
    return executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward removal of a prompt adapter id to the model executor.

    Returns:
        Whatever the executor's ``remove_prompt_adapter`` returns.
    """
    executor = self.model_executor
    return executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
    """Return the model executor's ``list_prompt_adapters()`` result."""
    executor = self.model_executor
    return executor.list_prompt_adapters()
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
def is_tracing_enabled(self) -> bool:
    """Return True when ``self.tracer`` is set (i.e. not None)."""
    tracer = self.tracer
    return tracer is not None
def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
    """Create a trace span for every scheduled sequence group that finished.

    Does nothing when no tracer is set.

    Args:
        scheduler_outputs: Scheduler result whose ``scheduled_seq_groups``
            are inspected.
    """
    if self.tracer is None:
        return

    scheduled_groups = (entry.seq_group
                        for entry in scheduler_outputs.scheduled_seq_groups)
    for group in scheduled_groups:
        if not group.is_finished():
            continue
        self.create_trace_span(group)
def create_trace_span(self, seq_group: SequenceGroup) -> None:
    """Emit one "llm_request" span describing a sequence group.

    Nothing is emitted when no tracer is set or when the group has no
    sampling params. The span is started at the group's arrival time and
    carries the request parameters, token usage and latency figures as
    attributes.

    Args:
        seq_group: Sequence group whose metrics, sampling params and trace
            headers populate the span.
    """
    if self.tracer is None:
        return
    if seq_group.sampling_params is None:
        return

    # The span start time is passed in nanoseconds, hence the 1e9 scaling.
    span_start_ns = int(seq_group.metrics.arrival_time * 1e9)
    parent_context = extract_trace_context(seq_group.trace_headers)

    with self.tracer.start_as_current_span(
            "llm_request",
            kind=SpanKind.SERVER,
            context=parent_context,
            start_time=span_start_ns) as seq_span:
        timing = seq_group.metrics
        ttft = timing.first_token_time - timing.arrival_time
        e2e_time = timing.finished_time - timing.arrival_time

        params = seq_group.sampling_params
        record = seq_span.set_attribute

        # attribute names are based on
        # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
        record(SpanAttributes.LLM_RESPONSE_MODEL, self.model_config.model)
        record(SpanAttributes.LLM_REQUEST_ID, seq_group.request_id)

        # Request parameters.
        record(SpanAttributes.LLM_REQUEST_TEMPERATURE, params.temperature)
        record(SpanAttributes.LLM_REQUEST_TOP_P, params.top_p)
        record(SpanAttributes.LLM_REQUEST_MAX_TOKENS, params.max_tokens)
        record(SpanAttributes.LLM_REQUEST_BEST_OF, params.best_of)
        record(SpanAttributes.LLM_REQUEST_N, params.n)

        # Usage figures.
        record(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, seq_group.num_seqs())
        record(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
               len(seq_group.prompt_token_ids))
        record(
            SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
            sum(seq.get_output_len()
                for seq in seq_group.get_finished_seqs()))

        # Latency figures.
        record(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
               timing.time_in_queue)
        record(SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
        record(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
import time import time
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union from typing import Dict, List, Optional, Protocol, Union
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, import prometheus_client
disable_created_metrics)
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger from vllm.logger import init_logger
if ray is not None:
from ray.util import metrics as ray_metrics
else:
ray_metrics = None
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
disable_created_metrics() prometheus_client.disable_created_metrics()
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions. # to extract the metrics definitions.
...@@ -24,56 +30,55 @@ disable_created_metrics() ...@@ -24,56 +30,55 @@ disable_created_metrics()
# begin-metrics-definitions # begin-metrics-definitions
class Metrics: class Metrics:
labelname_finish_reason = "finished_reason" labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
def __init__(self, labelnames: List[str], max_model_len: int): def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors # Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names): self._unregister_vllm_metrics()
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)
# Config Information # Config Information
self.info_cache_config = Info( self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config', name='vllm:cache_config',
documentation='information of cache_config') documentation='information of cache_config')
# System stats # System stats
# Scheduler State # Scheduler State
self.gauge_scheduler_running = Gauge( self.gauge_scheduler_running = self._base_library.Gauge(
name="vllm:num_requests_running", name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.", documentation="Number of requests currently running on GPU.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge( self.gauge_scheduler_waiting = self._base_library.Gauge(
name="vllm:num_requests_waiting", name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge( self.gauge_scheduler_swapped = self._base_library.Gauge(
name="vllm:num_requests_swapped", name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.", documentation="Number of requests swapped to CPU.",
labelnames=labelnames) labelnames=labelnames)
# KV Cache Usage in % # KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge( self.gauge_gpu_cache_usage = self._base_library.Gauge(
name="vllm:gpu_cache_usage_perc", name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames) labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge( self.gauge_cpu_cache_usage = self._base_library.Gauge(
name="vllm:cpu_cache_usage_perc", name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.", documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames) labelnames=labelnames)
# Iteration stats # Iteration stats
self.counter_num_preemption = Counter( self.counter_num_preemption = self._base_library.Counter(
name="vllm:num_preemptions_total", name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.", documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames) labelnames=labelnames)
self.counter_prompt_tokens = Counter( self.counter_prompt_tokens = self._base_library.Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames) labelnames=labelnames)
self.counter_generation_tokens = Counter( self.counter_generation_tokens = self._base_library.Counter(
name="vllm:generation_tokens_total", name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames) labelnames=labelnames)
self.histogram_time_to_first_token = Histogram( self.histogram_time_to_first_token = self._base_library.Histogram(
name="vllm:time_to_first_token_seconds", name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.", documentation="Histogram of time to first token in seconds.",
labelnames=labelnames, labelnames=labelnames,
...@@ -81,7 +86,7 @@ class Metrics: ...@@ -81,7 +86,7 @@ class Metrics:
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
]) ])
self.histogram_time_per_output_token = Histogram( self.histogram_time_per_output_token = self._base_library.Histogram(
name="vllm:time_per_output_token_seconds", name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.", documentation="Histogram of time per output token in seconds.",
labelnames=labelnames, labelnames=labelnames,
...@@ -92,59 +97,82 @@ class Metrics: ...@@ -92,59 +97,82 @@ class Metrics:
# Request stats # Request stats
# Latency # Latency
self.histogram_e2e_time_request = Histogram( self.histogram_e2e_time_request = self._base_library.Histogram(
name="vllm:e2e_request_latency_seconds", name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.", documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata # Metadata
self.histogram_num_prompt_tokens_request = Histogram( self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
name="vllm:request_prompt_tokens", name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames, labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
) )
self.histogram_num_generation_tokens_request = Histogram( self.histogram_num_generation_tokens_request = \
name="vllm:request_generation_tokens", self._base_library.Histogram(
documentation="Number of generation tokens processed.", name="vllm:request_generation_tokens",
labelnames=labelnames, documentation="Number of generation tokens processed.",
buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames,
) buckets=build_1_2_5_buckets(max_model_len),
self.histogram_best_of_request = Histogram( )
self.histogram_best_of_request = self._base_library.Histogram(
name="vllm:request_params_best_of", name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.", documentation="Histogram of the best_of request parameter.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1, 2, 5, 10, 20], buckets=[1, 2, 5, 10, 20],
) )
self.histogram_n_request = Histogram( self.histogram_n_request = self._base_library.Histogram(
name="vllm:request_params_n", name="vllm:request_params_n",
documentation="Histogram of the n request parameter.", documentation="Histogram of the n request parameter.",
labelnames=labelnames, labelnames=labelnames,
buckets=[1, 2, 5, 10, 20], buckets=[1, 2, 5, 10, 20],
) )
self.counter_request_success = Counter( self.counter_request_success = self._base_library.Counter(
name="vllm:request_success_total", name="vllm:request_success_total",
documentation="Count of successfully processed requests.", documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason]) labelnames=labelnames + [Metrics.labelname_finish_reason])
# Deprecated in favor of vllm:prompt_tokens_total # Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge( self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s", name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.", documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
) )
# Deprecated in favor of vllm:generation_tokens_total # Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge( self.gauge_avg_generation_throughput = self._base_library.Gauge(
name="vllm:avg_generation_throughput_toks_per_s", name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.", documentation="Average generation throughput in tokens/s.",
labelnames=labelnames, labelnames=labelnames,
) )
def _unregister_vllm_metrics(self) -> None:
    """Unregister every collector whose ``_name`` contains "vllm".

    Collectors are taken from the ``REGISTRY`` of ``self._base_library``;
    collectors without a ``_name`` attribute are left untouched.
    """
    registry = self._base_library.REGISTRY
    # Iterate over a snapshot, presumably because unregister() modifies
    # the registry's collector mapping.
    for collector in tuple(registry._collector_to_names):
        if not hasattr(collector, "_name"):
            continue
        if "vllm" in collector._name:
            registry.unregister(collector)
class RayMetrics(Metrics):
    """
    Metrics variant that builds its metrics with Ray's util.metrics library.

    Used by RayPrometheusStatLogger. It reuses Metrics.__init__ unchanged, so
    it exposes the same metric attributes as Metrics.
    """
    # NOTE(review): Metrics.__init__ creates metrics through _base_library
    # with prometheus-style keyword arguments (documentation=, labelnames=,
    # buckets=) -- confirm that ray.util.metrics accepts these names.
    _base_library = ray_metrics

    def __init__(self, labelnames: List[str], max_model_len: int):
        """Build the metrics, failing early when Ray is unavailable.

        Raises:
            ImportError: If the module-level ``ray_metrics`` is None.
        """
        ray_available = ray_metrics is not None
        if not ray_available:
            raise ImportError("RayMetrics requires Ray to be installed.")
        super().__init__(labelnames, max_model_len)

    def _unregister_vllm_metrics(self) -> None:
        """Skip the unregistration step performed by the base class."""
        # No-op on purpose
        return None
# end-metrics-definitions # end-metrics-definitions
def build_1_2_5_buckets(max_value: int): def build_1_2_5_buckets(max_value: int) -> List[int]:
""" """
Builds a list of buckets with increasing powers of 10 multiplied by Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum. mantissa values (1, 2, 5) until the value exceeds the specified maximum.
...@@ -155,7 +183,7 @@ def build_1_2_5_buckets(max_value: int): ...@@ -155,7 +183,7 @@ def build_1_2_5_buckets(max_value: int):
""" """
mantissa_lst = [1, 2, 5] mantissa_lst = [1, 2, 5]
exponent = 0 exponent = 0
buckets = [] buckets: List[int] = []
while True: while True:
for m in mantissa_lst: for m in mantissa_lst:
value = m * 10**exponent value = m * 10**exponent
...@@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol): ...@@ -206,34 +234,136 @@ class SupportsMetricsInfo(Protocol):
... ...
class StatLogger: def local_interval_elapsed(now: float, last_log: float,
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" local_interval: float) -> bool:
elapsed_time = now - last_log
return elapsed_time > local_interval
def get_throughput(tracked_stats: List[int], now: float,
                   last_log: float) -> float:
    """Return the average rate of ``tracked_stats`` since ``last_log``.

    The tracked values are summed and divided by the elapsed time
    ``now - last_log``; the quotient is converted to a built-in ``float``.
    """
    window = now - last_log
    total = np.sum(tracked_stats)
    return float(total / window)
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
class StatLoggerBase(ABC):
    """Abstract base class for stat loggers.

    Holds the state shared by the concrete loggers: the local logging
    interval, the time of the last local log, and the token counts
    collected since then.
    """

    def __init__(self, local_interval: float) -> None:
        # Bookkeeping for the local logging interval.
        self.local_interval = local_interval
        self.last_local_log = time.time()
        # Token counts tracked over the current local logging interval.
        self.num_generation_tokens: List[int] = []
        self.num_prompt_tokens: List[int] = []

    @abstractmethod
    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Record information about ``obj``; subclasses must override."""
        raise NotImplementedError

    @abstractmethod
    def log(self, stats: Stats) -> None:
        """Record one ``Stats`` snapshot; subclasses must override."""
        raise NotImplementedError
class LoggingStatLogger(StatLoggerBase):
    """LoggingStatLogger is used in LLMEngine to log to Stdout.

    Messages are emitted through the module-level ``logger``.
    """

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Not supported by this logger; always raises NotImplementedError."""
        raise NotImplementedError

    def log(self, stats: Stats) -> None:
        """Called by LLMEngine.

        Tracks the per-iteration token counts on every call and emits a
        summary line once more than ``self.local_interval`` seconds have
        passed since the previous summary.
        """
        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)

        # Log locally every local_interval seconds.
        if not local_interval_elapsed(stats.now, self.last_local_log,
                                      self.local_interval):
            return

        # Compute summary metrics for the stats tracked over this interval.
        prompt_throughput = get_throughput(self.num_prompt_tokens,
                                           now=stats.now,
                                           last_log=self.last_local_log)
        generation_throughput = get_throughput(
            self.num_generation_tokens,
            now=stats.now,
            last_log=self.last_local_log)

        # Log to stdout.
        logger.info(
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Swapped: %d reqs, "
            "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
            "CPU KV cache usage: %.1f%%.",
            prompt_throughput,
            generation_throughput,
            stats.num_running_sys,
            stats.num_swapped_sys,
            stats.num_waiting_sys,
            # Usage values are scaled by 100 to print as percentages.
            stats.gpu_cache_usage_sys * 100,
            stats.cpu_cache_usage_sys * 100,
        )

        # Reset tracked stats for next interval.
        self.num_prompt_tokens = []
        self.num_generation_tokens = []
        self.last_local_log = stats.now

        # Speculative decoding metrics are only logged when present.
        if stats.spec_decode_metrics is not None:
            logger.info(
                self._format_spec_decode_metrics_str(
                    stats.spec_decode_metrics))

    def _format_spec_decode_metrics_str(
            self, metrics: "SpecDecodeWorkerMetrics") -> str:
        """Render the speculative decoding metrics as one log line."""
        return ("Speculative metrics: "
                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
                f"System efficiency: {metrics.system_efficiency:.3f}, "
                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
                f"Number of accepted tokens: {metrics.accepted_tokens}, "
                f"Number of draft tokens: {metrics.draft_tokens}, "
                f"Number of emitted tokens: {metrics.emitted_tokens}.")
class PrometheusStatLogger(StatLoggerBase):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls = Metrics
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
super().__init__(local_interval)
# Prometheus metrics # Prometheus metrics
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()), self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
max_model_len=max_model_len) max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info()) self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float: def _log_gauge(self, gauge, data: Union[int, float]) -> None:
return float(np.sum(tracked_stats) / (now - self.last_local_log)) # Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _local_interval_elapsed(self, now: float) -> bool: def _log_counter(self, counter, data: Union[int, float]) -> None:
elapsed_time = now - self.last_local_log # Convenience function for logging to counter.
return elapsed_time > self.local_interval counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int],
List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus(self, stats: Stats) -> None: def _log_prometheus(self, stats: Stats) -> None:
# System state data # System state data
...@@ -279,26 +409,6 @@ class StatLogger: ...@@ -279,26 +409,6 @@ class StatLogger:
self._log_histogram(self.metrics.histogram_best_of_request, self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests) stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval. # Logs metrics to prometheus that are computed every logging_interval.
...@@ -313,11 +423,8 @@ class StatLogger: ...@@ -313,11 +423,8 @@ class StatLogger:
self.metrics.gauge_avg_generation_throughput.labels( self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput) **self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None: def log(self, stats: Stats):
"""Called by LLMEngine. """Logs to prometheus and tracked stats every iteration."""
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus. # Log to prometheus.
self._log_prometheus(stats) self._log_prometheus(stats)
...@@ -326,50 +433,28 @@ class StatLogger: ...@@ -326,50 +433,28 @@ class StatLogger:
self.num_generation_tokens.append(stats.num_generation_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds. # Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now): if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them # Compute summary metrics for tracked stats (and log them
# to promethus if applicable). # to promethus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens, prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now) now=stats.now,
generation_throughput = self._get_throughput( last_log=self.last_local_log)
self.num_generation_tokens, now=stats.now) generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
self._log_prometheus_interval( self._log_prometheus_interval(
prompt_throughput=prompt_throughput, prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput) generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
"CPU KV cache usage: %.1f%%.",
prompt_throughput,
generation_throughput,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval. # Reset tracked stats for next interval.
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: " class RayPrometheusStatLogger(PrometheusStatLogger):
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " """RayPrometheusStatLogger uses Ray metrics instead."""
f"System efficiency: {metrics.system_efficiency:.3f}, " _metrics_cls = RayMetrics
f"Number of speculative tokens: {metrics.num_spec_tokens}, " \ No newline at end of file
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
...@@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC): ...@@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC):
def create_output_processor( def create_output_processor(
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",
......
...@@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def __init__( def __init__(
self, self,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker, stop_checker: StopChecker,
...@@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
break break
if seq.is_finished(): if seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
...@@ -60,14 +60,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -60,14 +60,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
assert len(outputs) == 1, ("Single step should only has 1 output.") assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0] output = outputs[0]
prompt_logprobs = output.prompt_logprobs prompt_logprobs = output.prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None: if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []
if seq_group.sampling_params.detokenize and self.detokenizer: if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace( self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs) seq_group,
if not seq_group.prompt_logprobs: prompt_logprobs,
# The first prompt token's logprob is None because it doesn't position_offset=len(seq_group.prompt_logprobs))
# have tokens that are precedent.
seq_group.prompt_logprobs = [None]
seq_group.prompt_logprobs.extend(prompt_logprobs) seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
...@@ -95,7 +104,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -95,7 +104,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# not be used in the future iterations. # not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id) seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent) for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue continue
# Fork the parent sequence if there are multiple child samples. # Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]: for child_sample in child_samples[:-1]:
...@@ -133,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -133,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
...@@ -141,13 +152,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -141,13 +152,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# old sequences. # old sequences.
for seq, parent in child_seqs: for seq, parent in child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
return return
# Beam search case # Beam search case
# Select the child sequences to keep in the sequence group. # Select the child sequences to keep in the sequence group.
selected_child_seqs = [] selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = seq_group.sampling_params.best_of beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty length_penalty = seq_group.sampling_params.length_penalty
...@@ -226,13 +238,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -226,13 +238,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs: for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and # Remove the unselected parent sequences from the sequence group and
# free their memory in block manager. # free their memory in block manager.
...@@ -241,7 +255,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -241,7 +255,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Remove the parent sequence if it is not selected for next # Remove the parent sequence if it is not selected for next
# iteration # iteration
seq_group.remove(seq.seq_id) seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping( def _check_beam_search_early_stopping(
self, self,
......
...@@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please ...@@ -6,7 +6,6 @@ We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead. change `vllm/entrypoints/openai/api_server.py` instead.
""" """
import argparse
import json import json
import ssl import ssl
from typing import AsyncGenerator from typing import AsyncGenerator
...@@ -17,9 +16,12 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse ...@@ -17,9 +16,12 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
logger = init_logger("vllm.entrypoints.api_server")
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI() app = FastAPI()
...@@ -80,7 +82,7 @@ async def generate(request: Request) -> Response: ...@@ -80,7 +82,7 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-keyfile", type=str, default=None)
...@@ -108,6 +110,14 @@ if __name__ == "__main__": ...@@ -108,6 +110,14 @@ if __name__ == "__main__":
engine_args, usage_context=UsageContext.API_SERVER) engine_args, usage_context=UsageContext.API_SERVER)
app.root_path = args.root_path app.root_path = args.root_path
logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,
port=args.port, port=args.port,
......
...@@ -13,6 +13,7 @@ from vllm.logger import init_logger ...@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -121,6 +122,11 @@ class LLM: ...@@ -121,6 +122,11 @@ class LLM:
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
"image_input_shape", "image_input_type")
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -250,6 +256,7 @@ class LLM: ...@@ -250,6 +256,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -266,6 +273,8 @@ class LLM: ...@@ -266,6 +273,8 @@ class LLM:
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
...@@ -299,7 +308,7 @@ class LLM: ...@@ -299,7 +308,7 @@ class LLM:
inputs=inputs, inputs=inputs,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) prompt_adapter_request=prompt_adapter_request)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
...@@ -392,6 +401,7 @@ class LLM: ...@@ -392,6 +401,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -407,6 +417,8 @@ class LLM: ...@@ -407,6 +417,8 @@ class LLM:
use the default pooling parameters. use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
...@@ -440,6 +452,7 @@ class LLM: ...@@ -440,6 +452,7 @@ class LLM:
inputs=inputs, inputs=inputs,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
...@@ -499,6 +512,7 @@ class LLM: ...@@ -499,6 +512,7 @@ class LLM:
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -521,19 +535,23 @@ class LLM: ...@@ -521,19 +535,23 @@ class LLM:
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
) prompt_adapter_request=prompt_adapter_request)
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest],
LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(
inputs, request_id,
params, inputs,
lora_request=lora_request) params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
def _run_engine( def _run_engine(
self, *, use_tqdm: bool self, *, use_tqdm: bool
...@@ -545,11 +563,13 @@ class LLM: ...@@ -545,11 +563,13 @@ class LLM:
total=num_requests, total=num_requests,
desc="Processed prompts", desc="Processed prompts",
dynamic_ncols=True, dynamic_ncols=True,
postfix=f"Generation Speed: {0:.2f} toks/s", postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
) )
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0 total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
for output in step_outputs: for output in step_outputs:
...@@ -558,10 +578,15 @@ class LLM: ...@@ -558,10 +578,15 @@ class LLM:
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput # Calculate tokens only for RequestOutput
total_toks += sum( total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs) len(stp.token_ids) for stp in output.outputs)
spd = total_toks / pbar.format_dict["elapsed"] out_spd = total_out_toks / pbar.format_dict[
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" "elapsed"]
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
......
...@@ -8,30 +8,41 @@ from typing import Optional, Set ...@@ -8,30 +8,41 @@ from typing import Optional, Set
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import Request from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app from prometheus_client import make_asgi_app
from starlette.routing import Mount from starlette.routing import Mount
import vllm
import vllm.envs as envs import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
CompletionRequest, CompletionRequest,
EmbeddingRequest, ErrorResponse) DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
engine: AsyncLLMEngine
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding openai_serving_embedding: OpenAIServingEmbedding
...@@ -57,47 +68,57 @@ async def lifespan(app: fastapi.FastAPI): ...@@ -57,47 +68,57 @@ async def lifespan(app: fastapi.FastAPI):
yield yield
app = fastapi.FastAPI(lifespan=lifespan) router = APIRouter()
def parse_args():
parser = make_arg_parser()
return parser.parse_args()
# Add prometheus asgi middleware to route /metrics requests # Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app()) route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics # Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$') route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route) router.routes.append(route)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health") @router.get("/health")
async def health() -> Response: async def health() -> Response:
"""Health check.""" """Health check."""
await openai_serving_chat.engine.check_health() await openai_serving_chat.engine.check_health()
return Response(status_code=200) return Response(status_code=200)
@app.get("/v1/models") @router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump())
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump())
@router.get("/v1/models")
async def show_available_models(): async def show_available_models():
models = await openai_serving_chat.show_available_models() models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump()) return JSONResponse(content=models.model_dump())
@app.get("/version") @router.get("/version")
async def show_version(): async def show_version():
ver = {"version": vllm.__version__} ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver) return JSONResponse(content=ver)
@app.post("/v1/chat/completions") @router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request): raw_request: Request):
generator = await openai_serving_chat.create_chat_completion( generator = await openai_serving_chat.create_chat_completion(
...@@ -113,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -113,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions") @router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion( generator = await openai_serving_completion.create_completion(
request, raw_request) request, raw_request)
...@@ -127,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -127,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
@app.post("/v1/embeddings") @router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding( generator = await openai_serving_embedding.create_embedding(
request, raw_request) request, raw_request)
...@@ -138,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -138,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
if __name__ == "__main__": def build_app(args):
args = parse_args() app = fastapi.FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
...@@ -149,6 +172,12 @@ if __name__ == "__main__": ...@@ -149,6 +172,12 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers, allow_headers=args.allowed_headers,
) )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key: if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http") @app.middleware("http")
...@@ -174,7 +203,13 @@ if __name__ == "__main__": ...@@ -174,7 +203,13 @@ if __name__ == "__main__":
raise ValueError(f"Invalid middleware {middleware}. " raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.") f"Must be a function or a class.")
logger.info("vLLM API server version %s", vllm.__version__) return app
def run_server(args, llm_engine=None):
app = build_app(args)
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
if args.served_model_name is not None: if args.served_model_name is not None:
...@@ -182,19 +217,12 @@ if __name__ == "__main__": ...@@ -182,19 +217,12 @@ if __name__ == "__main__":
else: else:
served_model_names = [args.model] served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args) global engine, engine_args
# Enforce pixel values as image input type for vision language models engine_args = AsyncEngineArgs.from_cli_args(args)
# when serving with API server engine = (llm_engine
if engine_args.image_input_type is not None and \ if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args.image_input_type.upper() != "PIXEL_VALUES": engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
raise ValueError(
f"Invalid image_input_type: {engine_args.image_input_type}. "
"Only --image-input-type 'pixel_values' is supported for serving "
"vision language models with the vLLM API server.")
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
event_loop: Optional[asyncio.AbstractEventLoop] event_loop: Optional[asyncio.AbstractEventLoop]
try: try:
...@@ -210,16 +238,29 @@ if __name__ == "__main__": ...@@ -210,16 +238,29 @@ if __name__ == "__main__":
# When using single vLLM without engine_use_ray # When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config()) model_config = asyncio.run(engine.get_model_config())
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
openai_serving_chat = OpenAIServingChat(engine, model_config, openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names, served_model_names,
args.response_role, args.response_role,
args.lora_modules, args.lora_modules,
args.chat_template) args.chat_template)
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules) engine, model_config, served_model_names, args.lora_modules,
args.prompt_adapters)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names) served_model_names)
app.root_path = args.root_path app.root_path = args.root_path
logger.info("Available routes are:")
for route in app.routes:
if not hasattr(route, 'methods'):
continue
methods = ', '.join(route.methods)
logger.info("Route: %s, Methods: %s", route.path, methods)
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,
port=args.port, port=args.port,
...@@ -229,3 +270,13 @@ if __name__ == "__main__": ...@@ -229,3 +270,13 @@ if __name__ == "__main__":
ssl_certfile=args.ssl_certfile, ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs, ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs) ssl_cert_reqs=args.ssl_cert_reqs)
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
run_server(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment