Commit 99b471c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.1

parents 1925d2e9 468d761b
from .communication_op import *
from .parallel_state import *
from .utils import *
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import pynccl_utils from .parallel_state import (get_tensor_model_parallel_group,
from vllm.model_executor.parallel_utils.custom_all_reduce import ( get_tensor_model_parallel_rank,
custom_all_reduce) get_tensor_model_parallel_world_size,
from vllm.model_executor.parallel_utils.parallel_state import ( is_pynccl_enabled_for_all_reduce)
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: ...@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return TLDR: always assume this function modifies its input, but use the return
value as the output. value as the output.
""" """
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1: if get_tensor_model_parallel_world_size() == 1:
return input_ return input_
...@@ -142,7 +144,7 @@ def broadcast_tensor_dict( ...@@ -142,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]: ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.""" """Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group) ranks = torch.distributed.get_process_group_ranks(group)
...@@ -155,10 +157,10 @@ def broadcast_tensor_dict( ...@@ -155,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
if rank == src: if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}") dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
assert value.is_cuda, ( assert value.is_cuda, (
...@@ -171,19 +173,27 @@ def broadcast_tensor_dict( ...@@ -171,19 +173,27 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=group)
async_handles = []
for key, value in metadata_list: for key, value in metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = tensor_dict[key] tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group) async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else: else:
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
src=src, src=src,
group=group) group=group)
metadata_list = recv_metadata_list[0] assert recv_metadata_list[0] is not None
tensor_dict = {} tensor_dict = {}
async_handles = [] async_handles = []
for key, value in metadata_list: for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, tensor = torch.empty(value.size,
dtype=value.dtype, dtype=value.dtype,
......
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Any, List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
try: try:
import pynvml import pynvml
...@@ -19,12 +18,15 @@ except ImportError: ...@@ -19,12 +18,15 @@ except ImportError:
logger = init_logger(__name__) logger = init_logger(__name__)
_CA_HANDLE = None _CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False _IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None: def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE global _CA_HANDLE
if _CA_HANDLE is not None: if _CA_HANDLE is not None:
return return
...@@ -41,19 +43,39 @@ def init_custom_ar() -> None: ...@@ -41,19 +43,39 @@ def init_custom_ar() -> None:
" disable_custom_all_reduce=True explicitly.", world_size, " disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES)) str(_SUPPORTED_WORLD_SIZES))
return return
if not _can_p2p(rank, world_size): num_dev = torch.cuda.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warn( logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P" "Cannot test GPU P2P because not all GPUs are visible to the "
" capability or P2P test failed. To silence this warning, specify" "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" disable_custom_all_reduce=True explicitly.") " is set.")
return return
full_nvlink = _is_full_nvlink(rank, world_size) # test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = list(
map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink: if world_size > 2 and not full_nvlink:
logger.warn( logger.warn(
"Custom allreduce is disabled because it's not supported on more" "Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify" " than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.") " disable_custom_all_reduce=True explicitly.")
return return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
...@@ -95,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ...@@ -95,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle() ca_handle = get_handle()
# when custom allreduce is disabled, this will be None # when custom allreduce is disabled, this will be None
if ca_handle is None: if ca_handle is None:
return return None
if is_capturing(): if is_capturing():
if torch.cuda.is_current_stream_capturing(): if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input): if ca_handle.should_custom_ar(input):
...@@ -113,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ...@@ -113,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input): if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input) return ca_handle.all_reduce_unreg(input)
return None
@contextmanager @contextmanager
def _nvml(): def _nvml():
...@@ -123,59 +147,41 @@ def _nvml(): ...@@ -123,59 +147,41 @@ def _nvml():
pynvml.nvmlShutdown() pynvml.nvmlShutdown()
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
    """Report whether every pair of the given GPUs is linked by NVLink (1 hop).

    `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, so `device_ids`
    must be real physical device ids.

    Returns:
        True when every unordered pair reports `NVML_P2P_STATUS_OK` for the
        NVLink capability; False on the first pair that does not, or when
        the NVML query itself raises.
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(dev) for dev in device_ids]
    for position, handle in enumerate(handles):
        # Each unordered pair is visited once: only peers listed after
        # `handle` are queried.
        for peer_handle in handles[position + 1:]:
            try:
                status = pynvml.nvmlDeviceGetP2PStatus(
                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
            except pynvml.NVMLError as error:
                logger.error(
                    "NVLink detection failed. This is normal if your"
                    " machine has no NVLink equipped.",
                    exc_info=error)
                return False
            if status != pynvml.NVML_P2P_STATUS_OK:
                return False
    return True
def _can_p2p(rank: int, world_size: int) -> bool:
    """Check that `rank` has P2P access to every other rank's device.

    Delegates each pairwise check to `gpu_p2p_access_check` and stops at the
    first peer that fails.
    """
    # Imported lazily, as in the rest of this module.
    from vllm.distributed.utils import gpu_p2p_access_check

    return all(
        gpu_p2p_access_check(rank, peer) for peer in range(world_size)
        if peer != rank)
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a: int, idx_b: int) -> bool:
    """Verify P2P copies between two CUDA devices by a data round trip.

    A small random tensor is created on `cuda:{idx_a}`, copied to
    `cuda:{idx_b}` and copied back; the copies are considered working only
    if the data comes back unchanged.

    Args:
        idx_a: CUDA device index where the data starts and ends.
        idx_b: CUDA device index the data is bounced through.

    Returns:
        True if the round-tripped tensor equals the original element-wise.
    """
    dev_a = f"cuda:{idx_a}"
    dev_b = f"cuda:{idx_b}"
    original = torch.randn(5, device=dev_a) + 123.0
    round_trip = original.to(dev_b).to(dev_a)
    # Convert the 0-dim tensor to a real bool so callers get the plain
    # Python value the function name promises.
    return bool(torch.all(original == round_trip))
class CustomAllreduce: class CustomAllreduce:
# max_size: max supported allreduce size # max_size: max supported allreduce size
...@@ -220,14 +226,14 @@ class CustomAllreduce: ...@@ -220,14 +226,14 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data) return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
    """Gather `shard_data` from all ranks and split it into two lists.

    Each gathered entry is indexed as `(handle, offset)`; the result is
    `(handles, offsets)`, each ordered as returned by `all_gather_object`.
    """
    gathered: List[Optional[Any]] = [None] * self.world_size
    dist.all_gather_object(gathered, shard_data)
    handles = [entry[0] for entry in gathered]  # type: ignore
    offsets = [entry[1] for entry in gathered]  # type: ignore
    return handles, offsets
def register_buffer(self, inp: torch.Tensor): def register_buffer(self, inp: torch.Tensor):
......
...@@ -20,40 +20,37 @@ ...@@ -20,40 +20,37 @@
# variable in the code. # variable in the code.
import ctypes import ctypes
import datetime import platform
import os from typing import Optional, Union
# ===================== import region ===================== # ===================== import region =====================
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__) logger = init_logger(__name__)
so_file = find_nccl_library()

try:
    # Load the library in another process first: if it core dumps, it will
    # not crash the current process.
    nccl_integrity_check(so_file)
    nccl = ctypes.CDLL(so_file)
except Exception:
    # Each fragment ends with a space so the concatenated message reads as
    # separate sentences.
    logger.error(
        f"Failed to load NCCL library from {so_file}. "
        "It is expected if you are not running on NVIDIA/AMD GPUs. "
        "Otherwise, the nccl library might not exist, be corrupted "
        f"or it does not support the current platform {platform.platform()}. "
        "One solution is to download libnccl2 version 2.18 from "
        "https://developer.download.nvidia.com/compute/cuda/repos/ "
        "and extract the libnccl.so.2 file. If you already have the "
        "library, please set the environment variable VLLM_NCCL_SO_PATH"
        " to point to the correct nccl library path.")
    # Re-raise the original exception unchanged.
    raise
...@@ -63,6 +60,18 @@ except Exception as e: ...@@ -63,6 +60,18 @@ except Exception as e:
ncclResult_t = ctypes.c_int ncclResult_t = ctypes.c_int
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]
def NCCL_CHECK(result: ncclResult_t) -> None:
    """Raise `RuntimeError` with NCCL's own message if `result` is non-zero."""
    if result == 0:
        return
    message = _c_ncclGetErrorString(result).decode("utf-8")
    raise RuntimeError(f"NCCL error: {message}")
# equivalent to c declaration: # equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version); # ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion _c_ncclGetVersion = nccl.ncclGetVersion
...@@ -72,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] ...@@ -72,8 +81,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str: def ncclGetVersion() -> str:
version = ctypes.c_int() version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version)) NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
assert result == 0
# something like 21903 --> "2.19.3" # something like 21903 --> "2.19.3"
version_str = str(version.value) version_str = str(version.value)
major = version_str[0].lstrip("0") major = version_str[0].lstrip("0")
...@@ -95,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] ...@@ -95,8 +103,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
    """Ask NCCL for a fresh unique id and return it.

    Raises whatever `NCCL_CHECK` raises when the call reports an error.
    """
    new_id = NcclUniqueId()
    status = _c_ncclGetUniqueId(ctypes.byref(new_id))
    NCCL_CHECK(status)
    return new_id
...@@ -111,9 +118,10 @@ _c_ncclCommInitRank.argtypes = [ ...@@ -111,9 +118,10 @@ _c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
] ]
ncclDataType_t = ctypes.c_int
# enums
class ncclDataType_t(ctypes.c_int): class ncclDataTypeEnum:
ncclInt8 = 0 ncclInt8 = 0
ncclChar = 0 ncclChar = 0
ncclUint8 = 1 ncclUint8 = 1
...@@ -132,7 +140,7 @@ class ncclDataType_t(ctypes.c_int): ...@@ -132,7 +140,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10 ncclNumTypes = 10
@classmethod @classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8: if dtype == torch.int8:
return cls.ncclInt8 return cls.ncclInt8
if dtype == torch.uint8: if dtype == torch.uint8:
...@@ -152,7 +160,10 @@ class ncclDataType_t(ctypes.c_int): ...@@ -152,7 +160,10 @@ class ncclDataType_t(ctypes.c_int):
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int): ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0 ncclSum = 0
ncclProd = 1 ncclProd = 1
ncclMax = 2 ncclMax = 2
...@@ -161,7 +172,7 @@ class ncclRedOp_t(ctypes.c_int): ...@@ -161,7 +172,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5 ncclNumOps = 5
@classmethod @classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM: if op == ReduceOp.SUM:
return cls.ncclSum return cls.ncclSum
if op == ReduceOp.PRODUCT: if op == ReduceOp.PRODUCT:
...@@ -184,8 +195,8 @@ class ncclRedOp_t(ctypes.c_int): ...@@ -184,8 +195,8 @@ class ncclRedOp_t(ctypes.c_int):
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
# The data type comes before the reduction op, matching the order in which
# `all_reduce` passes them. Both aliases are `ctypes.c_int`, so the order
# has no runtime effect today, but it keeps the list honest if either alias
# ever changes.
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration: # equivalent to c declaration:
...@@ -199,66 +210,80 @@ class NCCLCommunicator: ...@@ -199,66 +210,80 @@ class NCCLCommunicator:
def __init__(
    self,
    group: Optional[ProcessGroup] = None,
    device: Optional[Union[int, str, torch.device]] = None,
):
    """
    Args:
        group: the process group to work on. If None, the group returned
            by `get_cpu_world_group()` is used. It must not be an NCCL
            group.
        device: the device to bind the NCCLCommunicator to. If None,
            it will be bound to f"cuda:{local_rank}".
    It is the caller's responsibility to make sure each communicator
    is bound to a unique device.
    """
    assert dist.is_initialized()
    group = get_cpu_world_group() if group is None else group
    assert dist.get_backend(group) != dist.Backend.NCCL, (
        "NCCLCommunicator should be attached to a non-NCCL group.")
    self.group = group
    self.rank = dist.get_rank(group)
    self.world_size = dist.get_world_size(group)
    # Only group rank 0 creates a real id; the others allocate a placeholder
    # that is overwritten by the broadcast below.
    if self.rank == 0:
        self.unique_id = ncclGetUniqueId()
    else:
        self.unique_id = NcclUniqueId()
    tensor = torch.ByteTensor(list(self.unique_id.internal))
    # `src` of `dist.broadcast` is a *global* rank, whereas `self.rank` is
    # the rank inside `group`. Translate group rank 0 to its global rank so
    # that sub-groups which do not contain global rank 0 also work.
    src_global_rank = dist.get_process_group_ranks(group)[0]
    dist.broadcast(tensor, src=src_global_rank, group=group)
    byte_list = tensor.tolist()
    for i, byte in enumerate(byte_list):
        self.unique_id.internal[i] = byte
    self.comm = ctypes.c_void_p()
    if device is None:
        local_rank = get_local_rank()
        device = torch.device(f"cuda:{local_rank}")
    elif isinstance(device, int):
        device = torch.device(f"cuda:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    # now `device` is a `torch.device` object
    assert isinstance(device, torch.device)
    self.device = device
    # nccl communicator and stream will use this device; the previously
    # current device is restored afterwards, even on failure.
    current_device = torch.cuda.current_device()
    try:
        torch.cuda.set_device(device)
        NCCL_CHECK(
            _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                self.unique_id, self.rank))
        self.stream = torch.cuda.Stream()
    finally:
        torch.cuda.set_device(current_device)
def all_reduce(self,
               tensor: torch.Tensor,
               op: ReduceOp = ReduceOp.SUM,
               stream=None):
    """All-reduce `tensor` in place through this communicator.

    Args:
        tensor: tensor whose storage is used as both send and receive
            buffer; it must live on `self.device`.
        op: reduction to apply.
        stream: CUDA stream to launch on; `self.stream` when None.
    """
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")
    launch_stream = self.stream if stream is None else stream
    # The same address is passed as send and receive buffer.
    buffer_ptr = ctypes.c_void_p(tensor.data_ptr())
    status = _c_ncclAllReduce(buffer_ptr, buffer_ptr, tensor.numel(),
                              ncclDataTypeEnum.from_torch(tensor.dtype),
                              ncclRedOpTypeEnum.from_torch(op), self.comm,
                              ctypes.c_void_p(launch_stream.cuda_stream))
    NCCL_CHECK(status)
def __del__(self):
    """Release the NCCL communicator (and the process group) on teardown.

    Written defensively because finalizers can run on a half-built object
    or during interpreter shutdown.
    """
    # `__init__` may have failed before `self.comm` was assigned.
    comm = getattr(self, "comm", None)
    # `dist` module might have been already destroyed
    # NOTE(review): this tears down the default process group even though
    # this class does not create it -- confirm callers rely on that.
    if (hasattr(dist, 'destroy_process_group')
            and hasattr(dist, 'is_initialized') and dist.is_initialized()):
        # Guarded by `is_initialized()` so that a group which is already
        # gone does not raise here and skip the communicator cleanup below.
        dist.destroy_process_group()
    # function might have been already destroyed
    if comm is not None and _c_ncclCommDestroy is not None:
        _c_ncclCommDestroy(comm)
...@@ -2,15 +2,15 @@ import contextlib ...@@ -2,15 +2,15 @@ import contextlib
from typing import Optional from typing import Optional
import torch import torch
from torch.distributed import ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetVersion) ncclGetVersion)
except Exception as e: except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module # in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs # e.g. when running on machines with AMD GPUs
...@@ -30,28 +30,24 @@ def is_initialized() -> bool: ...@@ -30,28 +30,24 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream): def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication""" """Set the cuda stream for communication"""
try: try:
assert comm is not None
comm.stream = stream comm.stream = stream
yield yield
finally: finally:
pass pass
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
    """Create the module-level `NCCLCommunicator` on top of `group`.

    Must only be called while no communicator has been initialized yet.
    """
    global comm
    assert not is_initialized()
    logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
    comm = NCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduce the CUDA tensor `input_` via the module-level communicator."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    communicator = comm
    assert communicator is not None
    communicator.all_reduce(input_, op)
...@@ -62,8 +58,9 @@ def destroy_process_group() -> None: ...@@ -62,8 +58,9 @@ def destroy_process_group() -> None:
def get_world_size() -> int:
    """Return `world_size` of the module-level communicator.

    The communicator must already have been created.
    """
    communicator = comm
    assert communicator is not None
    return communicator.world_size
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
    """Return the module-level communicator, or None if it was never set."""
    return comm
...@@ -4,24 +4,89 @@ ...@@ -4,24 +4,89 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups.""" """Tensor and pipeline parallel groups."""
import contextlib import contextlib
import os
from typing import Optional
import torch import torch
from vllm.model_executor.parallel_utils import pynccl_utils from vllm.logger import init_logger
logger = init_logger(__name__)
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None
# during `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None
# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one for device-specific (and is the default)
# and one for CPU. All processes will be part of both groups.
# A list of global ranks for each pipeline group to ease calculation of the # A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage. # source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
# Local rank of this process; stays -1 until something assigns it.
_LOCAL_RANK = -1


def get_local_rank() -> int:
    """Return the current value of the module-level `_LOCAL_RANK`."""
    # Reading a module global needs no `global` declaration.
    return _LOCAL_RANK
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed plus the module's world groups.

    When torch.distributed is not initialized yet, this creates the default
    process group with `backend`, records it as the device world group,
    creates a `gloo` group over all ranks as the CPU world group, and
    records the local rank.

    Args:
        world_size: forwarded to `torch.distributed.init_process_group`.
        rank: forwarded to `torch.distributed.init_process_group`.
        distributed_init_method: init method URL; must not be None.
        local_rank: local rank to record. With the default -1 and the
            "env://" init method it is read from `LOCAL_RANK` in the
            environment.
        backend: backend of the default (device) process group.
    """
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP, _LOCAL_RANK
    logger.debug(f"{world_size=} {rank=} {local_rank=} "
                 f"{distributed_init_method=} {backend=}")
    # NOTE(review): everything below is skipped when torch.distributed was
    # initialized beforehand, so the CPU group and local rank stay unset in
    # that case -- confirm this nesting is intended.
    if torch.distributed.is_initialized():
        return
    assert distributed_init_method is not None, (
        "distributed_init_method must be provided when initializing "
        "distributed environment")
    # this backend is used for WORLD
    torch.distributed.init_process_group(backend=backend,
                                         init_method=distributed_init_method,
                                         world_size=world_size,
                                         rank=rank)
    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
    every_rank = list(range(torch.distributed.get_world_size()))
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=every_rank,
                                                   backend="gloo")
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1 and distributed_init_method == "env://":
        local_rank = int(os.environ['LOCAL_RANK'])
    _LOCAL_RANK = local_rank
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None: ) -> None:
""" """
Initialize model parallel groups. Initialize model parallel groups.
...@@ -48,6 +113,8 @@ def initialize_model_parallel( ...@@ -48,6 +113,8 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size() world_size: int = torch.distributed.get_world_size()
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if (world_size != if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size): tensor_model_parallel_size * pipeline_model_parallel_size):
...@@ -69,7 +136,7 @@ def initialize_model_parallel( ...@@ -69,7 +136,7 @@ def initialize_model_parallel(
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size) (i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks: if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group _TENSOR_MODEL_PARALLEL_GROUP = group
...@@ -80,7 +147,7 @@ def initialize_model_parallel( ...@@ -80,7 +147,7 @@ def initialize_model_parallel(
"pipeline model parallel group is already initialized") "pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups): for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups) ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks, backend=backend)
if rank in ranks: if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks _PIPELINE_GLOBAL_RANKS = ranks
...@@ -89,14 +156,17 @@ def initialize_model_parallel( ...@@ -89,14 +156,17 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
backend: Optional[str] = None,
) -> None: ) -> None:
"""Helper to initialize model parallel groups if they are not initialized, """Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized. values if the model parallel groups are initialized.
""" """
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend()
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size) pipeline_model_parallel_size, backend)
return return
assert ( assert (
...@@ -117,6 +187,12 @@ def model_parallel_is_initialized(): ...@@ -117,6 +187,12 @@ def model_parallel_is_initialized():
and _PIPELINE_MODEL_PARALLEL_GROUP is not None) and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def get_cpu_world_group():
    """Return the module-level CPU world process group.

    Raises:
        AssertionError: if the CPU world group has not been set yet.
    """
    cpu_group = _CPU_WORLD_GROUP
    assert cpu_group is not None, "CPU world group is not initialized"
    return cpu_group
def get_tensor_model_parallel_group(): def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to.""" """Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
...@@ -209,6 +285,7 @@ def destroy_model_parallel(): ...@@ -209,6 +285,7 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pynccl_utils
# Destroy the pynccl states if any. # Destroy the pynccl states if any.
pynccl_utils.destroy_process_group() pynccl_utils.destroy_process_group()
...@@ -222,6 +299,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False ...@@ -222,6 +299,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager @contextlib.contextmanager
def with_pynccl_for_all_reduce(): def with_pynccl_for_all_reduce():
from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce""" """use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1: if tp_size == 1:
......
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence
import torch
import torch.distributed as dist
from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is an exact multiple of ``denominator``.

    Raises:
        AssertionError: if ``numerator % denominator`` is non-zero; the
            message names both operands.
    """
    remainder = numerator % denominator
    assert remainder == 0, (
        f"{numerator} is not divisible by {denominator}")
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking exact divisibility.

    Raises:
        AssertionError: propagated from ``ensure_divisibility`` when the
            division would leave a remainder.
    """
    # Validate first so a non-divisible pair fails instead of truncating.
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of chunks to produce; the size of the last
            dimension must be divisible by it.
        contiguous_split_chunks: if True, call ``.contiguous()`` on every
            chunk before returning.

    Returns:
        A sequence of tensors, one per partition.
    """
    split_dim = tensor.dim() - 1
    # `divide` asserts that the last dimension splits evenly.
    chunk_width = divide(tensor.shape[split_dim], num_partitions)
    chunks = torch.split(tensor, chunk_width, dim=split_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if not contiguous_split_chunks:
        return chunks
    return tuple(piece.contiguous() for piece in chunks)
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
    """Round-trip a small tensor between two CUDA devices and compare.

    A 5-element random tensor (shifted by 123.0) is created on device
    ``idx_a``, copied to device ``idx_b`` and copied back. Returns True only
    if every element of the round-tripped tensor equals the original.
    """
    src_device = f"cuda:{idx_a}"
    dst_device = f"cuda:{idx_b}"
    original = torch.randn(5, device=src_device) + 123.0
    round_tripped = original.to(dst_device).to(src_device)
    all_equal = torch.all(original == round_tripped)
    return all_equal.cpu().item()
# Why do we need this cache?
# 1. We could check P2P access at runtime, with every process checking
#    P2P access to all other GPUs. Unfortunately, that test might create many
#    (world_size * world_size) CUDA contexts and reduce the memory available
#    for the model. See https://github.com/vllm-project/vllm/issues/3821
# 2. Alternatively, the master process could generate a P2P map and
#    broadcast it to all other processes. This still requires #world_size
#    CUDA contexts, all belonging to the master process, one on each GPU.
# 3. We can keep a cache file that records the P2P access status. The first
#    time the master process checks P2P access, it generates the cache
#    file, at the cost of #world_size CUDA contexts. Later on, all processes
#    can read the cache file to check the P2P access status without the cost
#    of any additional CUDA context.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that
# different CUDA_VISIBLE_DEVICES settings (e.g. used by different vLLM
# engines) get different cache files. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vLLM engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j.

    The result is looked up in a module-level map that is loaded once from a
    JSON cache file under ``~/.config/vllm``. The file name embeds the
    ``CUDA_VISIBLE_DEVICES`` value (or ``0,1,...`` when it is unset). If the
    file does not exist, it is generated by the non-distributed process or,
    when ``torch.distributed`` is initialized, by the process whose local
    rank is 0. In the distributed case all processes wait on a barrier over
    the CPU world group before reading the file.
    """
    global _gpu_p2p_access_cache
    pair_key = f"{i}->{j}"

    # Fast path: the map has already been loaded in this process.
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[pair_key]

    is_distributed = dist.is_initialized()
    num_dev = torch.cuda.device_count()

    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        visible = ",".join(map(str, range(num_dev)))
    cache_file = os.path.expanduser(
        f"~/.config/vllm/gpu_p2p_access_cache_for_{visible}.json")
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    # Only the local master process (with local_rank == 0), or a process
    # running without torch.distributed, may compute the cache.
    may_generate = (not is_distributed) or get_local_rank() == 0
    if may_generate and not os.path.exists(cache_file):
        logger.info(f"generating GPU P2P access cache for in {cache_file}")
        # On some platforms, P2P support might be buggy and we need
        # additional checks. See also:
        # https://github.com/vllm-project/vllm/issues/2728
        p2p_map = {
            f"{src}->{dst}":
            torch.cuda.can_device_access_peer(src, dst)
            and _can_actually_p2p(src, dst)
            for src in range(num_dev) for dst in range(num_dev)
        }
        with open(cache_file, "w") as f:
            json.dump(p2p_map, f, indent=4)

    # Wait until the generating process has finished writing the file.
    if is_distributed:
        dist.barrier(get_cpu_world_group())

    logger.info(f"reading GPU P2P access cache from {cache_file}")
    with open(cache_file, "r") as f:
        _gpu_p2p_access_cache = json.load(f)
    return _gpu_p2p_access_cache[pair_key]
import argparse import argparse
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
ParallelConfig, SchedulerConfig, TokenizerPoolConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
VisionLanguageConfig) ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple from vllm.utils import str_to_int_tuple
...@@ -14,12 +16,14 @@ class EngineArgs: ...@@ -14,12 +16,14 @@ class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str model: str
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto' tokenizer_mode: str = 'auto'
trust_remote_code: bool = False trust_remote_code: bool = False
download_dir: Optional[str] = None download_dir: Optional[str] = None
load_format: str = 'auto' load_format: str = 'auto'
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
...@@ -53,8 +57,9 @@ class EngineArgs: ...@@ -53,8 +57,9 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
forced_num_gpu_blocks: Optional[int] = None num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
image_input_type: Optional[str] = None image_input_type: Optional[str] = None
...@@ -64,6 +69,12 @@ class EngineArgs: ...@@ -64,6 +69,12 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False enable_chunked_prefill: bool = False
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
self.tokenizer = self.model self.tokenizer = self.model
...@@ -73,72 +84,79 @@ class EngineArgs: ...@@ -73,72 +84,79 @@ class EngineArgs:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# NOTE: If you update any of the arguments below, please also
# make sure to update docs/source/models/engine_args.rst
# Model arguments # Model arguments
parser.add_argument( parser.add_argument(
'--model', '--model',
type=str, type=str,
default='facebook/opt-125m', default='facebook/opt-125m',
help='name or path of the huggingface model to use') help='Name or path of the huggingface model to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=str, type=str,
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use') help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=str, type=str,
default=None, default=None,
help='the specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=str, type=str,
default=None, default=None,
help='the specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=str, type=str,
default=None, default=None,
help='the specific tokenizer version to use. It can be a branch ' help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument('--tokenizer-mode', parser.add_argument(
type=str, '--tokenizer-mode',
default=EngineArgs.tokenizer_mode, type=str,
choices=['auto', 'slow'], default=EngineArgs.tokenizer_mode,
help='tokenizer mode. "auto" will use the fast ' choices=['auto', 'slow'],
'tokenizer if available, and "slow" will ' help='The tokenizer mode.\n\n* "auto" will use the '
'always use the slow tokenizer.') 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code', parser.add_argument('--trust-remote-code',
action='store_true', action='store_true',
help='trust remote code from huggingface') help='Trust remote code from huggingface.')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=str,
default=EngineArgs.download_dir, default=EngineArgs.download_dir,
help='directory to download and load the weights, ' help='Directory to download and load the weights, '
'default to the default cache dir of ' 'default to the default cache dir of '
'huggingface') 'huggingface.')
parser.add_argument( parser.add_argument(
'--load-format', '--load-format',
type=str, type=str,
default=EngineArgs.load_format, default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], choices=[
help='The format of the model weights to load. ' 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
'"auto" will try to load the weights in the safetensors format ' ],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format ' 'and fall back to the pytorch bin format if safetensors format '
'is not available. ' 'is not available.\n'
'"pt" will load the weights in the pytorch bin format. ' '* "pt" will load the weights in the pytorch bin format.\n'
'"safetensors" will load the weights in the safetensors format. ' '* "safetensors" will load the weights in the safetensors format.\n'
'"npcache" will load the weights in pytorch format and store ' '* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. ' 'a numpy cache to speed up the loading.\n'
'"dummy" will initialize the weights with random values, ' '* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.') 'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of '
'the serialized weights.')
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
...@@ -146,80 +164,117 @@ class EngineArgs: ...@@ -146,80 +164,117 @@ class EngineArgs:
choices=[ choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
], ],
help='data type for model weights and activations. ' help='Data type for model weights and activations.\n\n'
'The "auto" option will use FP16 precision ' '* "auto" will use FP16 precision for FP32 and FP16 models, and '
'for FP32 and FP16 models, and BF16 precision ' 'BF16 precision for BF16 models.\n'
'for BF16 models.') '* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument( parser.add_argument(
'--kv-cache-dtype', '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8_e5m2'], choices=['auto', 'fp8'],
default=EngineArgs.kv_cache_dtype, default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. Note FP8 is not supported when cuda version is ' 'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'lower than 11.8.') 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version'
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=int, type=int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
help='model context length. If unspecified, ' help='Model context length. If unspecified, will '
'will be automatically derived from the model.') 'be automatically derived from the model config.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', parser.add_argument('--worker-use-ray',
action='store_true', action='store_true',
help='use Ray for distributed serving, will be ' help='Use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU') 'automatically set when using more than 1 GPU.')
parser.add_argument('--pipeline-parallel-size', parser.add_argument('--pipeline-parallel-size',
'-pp', '-pp',
type=int, type=int,
default=EngineArgs.pipeline_parallel_size, default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages') help='Number of pipeline stages.')
parser.add_argument('--tensor-parallel-size', parser.add_argument('--tensor-parallel-size',
'-tp', '-tp',
type=int, type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas') help='Number of tensor parallel replicas.')
parser.add_argument( parser.add_argument(
'--max-parallel-loading-workers', '--max-parallel-loading-workers',
type=int, type=int,
default=EngineArgs.max_parallel_loading_workers, default=EngineArgs.max_parallel_loading_workers,
help='load model sequentially in multiple batches, ' help='Load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor ' 'to avoid RAM OOM when using tensor '
'parallel and large models') 'parallel and large models.')
parser.add_argument( parser.add_argument(
'--ray-workers-use-nsight', '--ray-workers-use-nsight',
action='store_true', action='store_true',
help='If specified, use nsight to profile ray workers') help='If specified, use nsight to profile Ray workers.')
# KV cache arguments # KV cache arguments
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32, 128], choices=[8, 16, 32],
help='token block size') help='Token block size for contiguous chunks of '
'tokens.')
parser.add_argument('--enable-prefix-caching', parser.add_argument('--enable-prefix-caching',
action='store_true', action='store_true',
help='Enables automatic prefix caching') help='Enables automatic prefix caching.')
parser.add_argument('--use-v2-block-manager', parser.add_argument('--use-v2-block-manager',
action='store_true', action='store_true',
help='Use BlockSpaceMangerV2') help='Use BlockSpaceMangerV2.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
default=EngineArgs.seed, default=EngineArgs.seed,
help='random seed') help='Random seed for operations.')
parser.add_argument('--swap-space', parser.add_argument('--swap-space',
type=int, type=int,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU') help='CPU swap space size (GiB) per GPU.')
parser.add_argument( parser.add_argument(
'--gpu-memory-utilization', '--gpu-memory-utilization',
type=float, type=float,
default=EngineArgs.gpu_memory_utilization, default=EngineArgs.gpu_memory_utilization,
help='the fraction of GPU memory to be used for ' help='The fraction of GPU memory to be used for the model '
'the model executor, which can range from 0 to 1.' 'executor, which can range from 0 to 1. For example, a value of '
'If unspecified, will use the default value of 0.9.') '0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.')
parser.add_argument( parser.add_argument(
'--forced-num-gpu-blocks', '--num-gpu-blocks-override',
type=int, type=int,
default=None, default=None,
help='If specified, ignore GPU profiling result and use this number' help='If specified, ignore GPU profiling result and use this number'
...@@ -227,26 +282,26 @@ class EngineArgs: ...@@ -227,26 +282,26 @@ class EngineArgs:
parser.add_argument('--max-num-batched-tokens', parser.add_argument('--max-num-batched-tokens',
type=int, type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per ' help='Maximum number of batched tokens per '
'iteration') 'iteration.')
parser.add_argument('--max-num-seqs', parser.add_argument('--max-num-seqs',
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='Maximum number of sequences per iteration.')
parser.add_argument( parser.add_argument(
'--max-logprobs', '--max-logprobs',
type=int, type=int,
default=EngineArgs.max_logprobs, default=EngineArgs.max_logprobs,
help=('max number of log probs to return logprobs is specified in' help=('Max number of log probs to return logprobs is specified in'
' SamplingParams')) ' SamplingParams.'))
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='disable logging statistics') help='Disable logging statistics.')
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=str, type=str,
choices=['awq', 'gptq', 'squeezellm', None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` ' 'None, we first check the `quantization_config` '
...@@ -262,13 +317,13 @@ class EngineArgs: ...@@ -262,13 +317,13 @@ class EngineArgs:
parser.add_argument('--max-context-len-to-capture', parser.add_argument('--max-context-len-to-capture',
type=int, type=int,
default=EngineArgs.max_context_len_to_capture, default=EngineArgs.max_context_len_to_capture,
help='maximum context length covered by CUDA ' help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce', parser.add_argument('--disable-custom-all-reduce',
action='store_true', action='store_true',
default=EngineArgs.disable_custom_all_reduce, default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig') help='See ParallelConfig.')
parser.add_argument('--tokenizer-pool-size', parser.add_argument('--tokenizer-pool-size',
type=int, type=int,
default=EngineArgs.tokenizer_pool_size, default=EngineArgs.tokenizer_pool_size,
...@@ -324,7 +379,7 @@ class EngineArgs: ...@@ -324,7 +379,7 @@ class EngineArgs:
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
choices=["auto", "cuda", "neuron"], choices=["auto", "cuda", "neuron", "cpu"],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
parser.add_argument( parser.add_argument(
...@@ -359,10 +414,41 @@ class EngineArgs: ...@@ -359,10 +414,41 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False, help='If set, the prefill requests can be chunked based on the '
help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens.')
'max_num_batched_tokens')
parser.add_argument(
'--speculative-model',
type=str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=str,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. '
'This will be passed to the model loader '
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
return parser return parser
@classmethod @classmethod
...@@ -373,23 +459,19 @@ class EngineArgs: ...@@ -373,23 +459,19 @@ class EngineArgs:
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args return engine_args
def create_engine_configs( def create_engine_config(self, ) -> EngineConfig:
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig],
Optional[VisionLanguageConfig]]:
device_config = DeviceConfig(self.device) device_config = DeviceConfig(self.device)
model_config = ModelConfig( model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.download_dir, self.load_format, self.trust_remote_code, self.dtype, self.seed, self.revision,
self.dtype, self.seed, self.revision, self.code_revision, self.code_revision, self.tokenizer_revision, self.max_model_len,
self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture, self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs) self.max_logprobs, self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
self.forced_num_gpu_blocks, self.num_gpu_blocks_override,
model_config.get_sliding_window(), model_config.get_sliding_window(),
self.enable_prefix_caching) self.enable_prefix_caching)
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
...@@ -401,11 +483,26 @@ class EngineArgs: ...@@ -401,11 +483,26 @@ class EngineArgs:
self.tokenizer_pool_type, self.tokenizer_pool_type,
self.tokenizer_pool_extra_config, self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight) ), self.ray_workers_use_nsight)
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
)
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
self.max_num_batched_tokens, self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len, model_config.max_model_len,
self.use_v2_block_manager, self.use_v2_block_manager,
num_lookahead_slots=(self.num_lookahead_slots
if speculative_config is None else
speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor, delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
) )
...@@ -417,6 +514,12 @@ class EngineArgs: ...@@ -417,6 +514,12 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
)
if self.image_input_type: if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size): or not self.image_feature_size):
...@@ -433,8 +536,19 @@ class EngineArgs: ...@@ -433,8 +536,19 @@ class EngineArgs:
else: else:
vision_language_config = None vision_language_config = None
return (model_config, cache_config, parallel_config, scheduler_config, decoding_config = DecodingConfig(
device_config, lora_config, vision_language_config) guided_decoding_backend=self.guided_decoding_backend)
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config)
@dataclass @dataclass
...@@ -445,20 +559,31 @@ class AsyncEngineArgs(EngineArgs): ...@@ -445,20 +559,31 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(parser: argparse.ArgumentParser,
parser: argparse.ArgumentParser) -> argparse.ArgumentParser: async_args_only: bool = False) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser) if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', parser.add_argument('--engine-use-ray',
action='store_true', action='store_true',
help='use Ray to start the LLM engine in a ' help='Use Ray to start the LLM engine in a '
'separate process as the server process.') 'separate process as the server process.')
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='disable logging requests') help='Disable logging requests.')
parser.add_argument('--max-log-len', parser.add_argument('--max-log-len',
type=int, type=int,
default=None, default=None,
help='max number of prompt characters or prompt ' help='Max number of prompt characters or prompt '
'ID numbers being printed in log. ' 'ID numbers being printed in log.'
'Default: unlimited.') '\n\nDefault: Unlimited')
return parser return parser
# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Build a fresh parser populated by ``EngineArgs.add_cli_args``."""
    base_parser = argparse.ArgumentParser()
    return EngineArgs.add_cli_args(base_parser)
def _async_engine_args_parser():
    """Build a fresh parser holding only the async-specific options.

    ``async_args_only=True`` is passed so that ``AsyncEngineArgs.add_cli_args``
    does not also register the base ``EngineArgs`` options.
    """
    base_parser = argparse.ArgumentParser()
    return AsyncEngineArgs.add_cli_args(base_parser, async_args_only=True)
...@@ -2,8 +2,8 @@ import asyncio ...@@ -2,8 +2,8 @@ import asyncio
import os import os
import time import time
from functools import partial from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Set, Tuple, Type, Union) Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -52,7 +52,7 @@ class AsyncStream: ...@@ -52,7 +52,7 @@ class AsyncStream:
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
self.request_id = request_id self.request_id = request_id
self._queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None: def put(self, item: Union[RequestOutput, Exception]) -> None:
...@@ -217,7 +217,15 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -217,7 +217,15 @@ class _AsyncLLMEngine(LLMEngine):
else: else:
output = [] output = []
return self._process_model_outputs(output, scheduler_outputs) request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
async def encode_request_async( async def encode_request_async(
self, self,
...@@ -310,15 +318,17 @@ class AsyncLLMEngine: ...@@ -310,15 +318,17 @@ class AsyncLLMEngine:
self.max_log_len = max_log_len self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs) self.engine = self._init_engine(*args, **kwargs)
self.background_loop = None self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded # We need to keep a reference to unshielded
# task as well to prevent it from being garbage # task as well to prevent it from being garbage
# collected # collected
self._background_loop_unshielded = None self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop self.start_engine_loop = start_engine_loop
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Lazy initialized fields
self._request_tracker: RequestTracker
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
...@@ -328,28 +338,31 @@ class AsyncLLMEngine: ...@@ -328,28 +338,31 @@ class AsyncLLMEngine:
) -> "AsyncLLMEngine": ) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_config = engine_args.create_engine_config()
parallel_config = engine_configs[2]
device_config = engine_configs[4] if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
if device_config.device_type == "neuron": executor_class = NeuronExecutorAsync
raise NotImplementedError("Neuron is not supported for " elif engine_config.device_config.device_type == "cpu":
"async engine yet.") assert not engine_config.parallel_config.worker_use_ray, (
elif parallel_config.worker_use_ray or engine_args.engine_use_ray: "Ray is not supported with the CPU backend.")
initialize_ray_cluster(parallel_config) from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync executor_class = RayGPUExecutorAsync
else: else:
assert parallel_config.world_size == 1, ( assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.") "Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync executor_class = GPUExecutorAsync
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
parallel_config.worker_use_ray, engine_config.parallel_config.worker_use_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
*engine_configs, **engine_config.to_dict(),
executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len, max_log_len=engine_args.max_log_len,
...@@ -361,11 +374,13 @@ class AsyncLLMEngine: ...@@ -361,11 +374,13 @@ class AsyncLLMEngine:
@property
def is_running(self) -> bool:
    """Whether the background loop has been started and is not done yet.

    Returns False when either ``background_loop`` or the unshielded task
    reference is still None, so the ``done()`` call is never made on None.
    """
    return (self.background_loop is not None
            and self._background_loop_unshielded is not None
            and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
    """Whether the engine has errored or its background loop has finished.

    The loop only counts as finished when both ``background_loop`` and the
    unshielded task reference are set and the task reports ``done()``.
    """
    return self.errored or (self.background_loop is not None and
                            self._background_loop_unshielded is not None
                            and self._background_loop_unshielded.done())
@property @property
...@@ -381,7 +396,7 @@ class AsyncLLMEngine: ...@@ -381,7 +396,7 @@ class AsyncLLMEngine:
async def get_tokenizer(self) -> "PreTrainedTokenizer":
    """Return the tokenizer held by the wrapped engine.

    When ``engine_use_ray`` is set, the call goes through ``.remote()`` and
    is awaited; otherwise the engine method is called directly.
    """
    if self.engine_use_ray:
        return await self.engine.get_tokenizer.remote()  # type: ignore
    else:
        return self.engine.get_tokenizer()
...@@ -411,8 +426,8 @@ class AsyncLLMEngine: ...@@ -411,8 +426,8 @@ class AsyncLLMEngine:
else: else:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the # FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments. # order of the arguments.
cache_config = args[1] cache_config = kwargs["cache_config"]
parallel_config = args[2] parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1: if parallel_config.tensor_parallel_size == 1:
num_gpus = cache_config.gpu_memory_utilization num_gpus = cache_config.gpu_memory_utilization
else: else:
...@@ -434,7 +449,8 @@ class AsyncLLMEngine: ...@@ -434,7 +449,8 @@ class AsyncLLMEngine:
# TODO: Maybe add add_request_batch to reduce Ray overhead # TODO: Maybe add add_request_batch to reduce Ray overhead
try: try:
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.add_request.remote(**new_request) await self.engine.add_request.remote( # type: ignore
**new_request)
else: else:
await self.engine.add_request_async(**new_request) await self.engine.add_request_async(**new_request)
except ValueError as e: except ValueError as e:
...@@ -449,7 +465,7 @@ class AsyncLLMEngine: ...@@ -449,7 +465,7 @@ class AsyncLLMEngine:
await self._engine_abort(finished_requests) await self._engine_abort(finished_requests)
if self.engine_use_ray: if self.engine_use_ray:
request_outputs = await self.engine.step.remote() request_outputs = await self.engine.step.remote() # type: ignore
else: else:
request_outputs = await self.engine.step_async() request_outputs = await self.engine.step_async()
...@@ -462,7 +478,7 @@ class AsyncLLMEngine: ...@@ -462,7 +478,7 @@ class AsyncLLMEngine:
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids) await self.engine.abort_request.remote(request_ids) # type: ignore
else: else:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
...@@ -525,11 +541,12 @@ class AsyncLLMEngine: ...@@ -525,11 +541,12 @@ class AsyncLLMEngine:
arrival_time = time.time() arrival_time = time.time()
if self.engine_use_ray: if self.engine_use_ray:
prompt_token_ids = await self.engine.encode_request_async.remote( prompt_token_ids = await (
request_id=request_id, self.engine.encode_request_async.remote( # type: ignore
prompt=prompt, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt=prompt,
lora_request=lora_request) prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
else: else:
prompt_token_ids = await self.engine.encode_request_async( prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id, request_id=request_id,
...@@ -676,13 +693,13 @@ class AsyncLLMEngine: ...@@ -676,13 +693,13 @@ class AsyncLLMEngine:
async def get_model_config(self) -> ModelConfig:
    """Get the model configuration of the vLLM engine.

    When ``engine_use_ray`` is set, the call goes through ``.remote()`` and
    is awaited; otherwise the engine method is called directly.
    """
    if self.engine_use_ray:
        return await self.engine.get_model_config.remote()  # type: ignore
    else:
        return self.engine.get_model_config()
async def do_log_stats(self) -> None:
    """Ask the wrapped engine to log its statistics.

    When ``engine_use_ray`` is set, the call goes through ``.remote()`` and
    is awaited; otherwise the engine method is called synchronously.
    """
    if self.engine_use_ray:
        await self.engine.do_log_stats.remote()  # type: ignore
    else:
        self.engine.do_log_stats()
...@@ -695,7 +712,7 @@ class AsyncLLMEngine: ...@@ -695,7 +712,7 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
try: try:
await self.engine.check_health.remote() await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e: except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e raise RuntimeError("Engine is dead.") from e
else: else:
......
import time import time
from typing import Iterable, List, Optional, Tuple, Type, Union from typing import Iterable, List, Optional, Type, Union
from transformers import PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
import vllm import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader import get_architecture_class_name
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceGroup, SequenceStage)
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
...@@ -30,6 +34,17 @@ logger = init_logger(__name__) ...@@ -30,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
    """Return the model's generation config as a diff dict.

    The config is loaded with ``GenerationConfig.from_pretrained`` using
    ``model_config.model`` and ``model_config.revision``. If an ``OSError``
    is raised, an empty dict is returned instead.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(
            model_config.model, revision=model_config.revision)
        return generation_config.to_diff_dict()
    except OSError:
        # Not found.
        return {}
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
...@@ -53,6 +68,11 @@ class LLMEngine: ...@@ -53,6 +68,11 @@ class LLMEngine:
parallel_config: The configuration related to distributed execution. parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device. device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision
language models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
...@@ -66,8 +86,11 @@ class LLMEngine: ...@@ -66,8 +86,11 @@ class LLMEngine:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional["VisionLanguageConfig"], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -75,22 +98,26 @@ class LLMEngine: ...@@ -75,22 +98,26 @@ class LLMEngine:
logger.info( logger.info(
f"Initializing an LLM engine (v{vllm.__version__}) with config: " f"Initializing an LLM engine (v{vllm.__version__}) with config: "
f"model={model_config.model!r}, " f"model={model_config.model!r}, "
f"speculative_config={speculative_config!r}, "
f"tokenizer={model_config.tokenizer!r}, " f"tokenizer={model_config.tokenizer!r}, "
f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
f"tokenizer_mode={model_config.tokenizer_mode}, " f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, " f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, " f"tokenizer_revision={model_config.tokenizer_revision}, "
f"trust_remote_code={model_config.trust_remote_code}, " f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, " f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, " f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, " f"download_dir={load_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={load_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce=" f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, " f"{parallel_config.disable_custom_all_reduce}, "
f"quantization={model_config.quantization}, " f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, " f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, " f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, " f"device_config={device_config.device}, "
f"decoding_config={decoding_config!r}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
...@@ -101,20 +128,41 @@ class LLMEngine: ...@@ -101,20 +128,41 @@ class LLMEngine:
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args()
self._init_tokenizer() if not self.model_config.skip_tokenizer_init:
self.detokenizer = Detokenizer(self.tokenizer) self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
load_config=load_config,
)
self.model_executor = executor_class(model_config, cache_config, self._initialize_kv_caches()
parallel_config, scheduler_config,
device_config, lora_config,
vision_language_config)
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled(): if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage( usage_message.report_usage(
get_architecture_class_name(model_config), get_architecture_class_name(model_config),
usage_context, usage_context,
...@@ -146,9 +194,10 @@ class LLMEngine: ...@@ -146,9 +194,10 @@ class LLMEngine:
parallel_config.disable_custom_all_reduce, parallel_config.disable_custom_all_reduce,
}) })
# Ping the tokenizer to ensure liveness if it runs in a if self.tokenizer:
# different process. # Ping the tokenizer to ensure liveness if it runs in a
self.tokenizer.ping() # different process.
self.tokenizer.ping()
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
...@@ -162,6 +211,41 @@ class LLMEngine: ...@@ -162,6 +211,41 @@ class LLMEngine:
labels=dict(model_name=model_config.model)) labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config) self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks())
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(f"Overriding {num_gpu_blocks=} with "
f"{num_gpu_blocks_override=}")
num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
...@@ -170,27 +254,28 @@ class LLMEngine: ...@@ -170,27 +254,28 @@ class LLMEngine:
) -> "LLMEngine": ) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments.""" """Creates an LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_config = engine_args.create_engine_config()
parallel_config = engine_configs[2]
device_config = engine_configs[4]
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
if device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif parallel_config.worker_use_ray: elif engine_config.device_config.device_type == "cpu":
initialize_ray_cluster(parallel_config) from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor executor_class = RayGPUExecutor
else: else:
assert parallel_config.world_size == 1, ( assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.") "Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor executor_class = GPUExecutor
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
*engine_configs, **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
...@@ -219,7 +304,7 @@ class LLMEngine: ...@@ -219,7 +304,7 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision) revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs) init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs) self.parallel_config.tokenizer_pool_config, **init_kwargs)
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -316,8 +401,13 @@ class LLMEngine: ...@@ -316,8 +401,13 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.tokenizer.get_lora_tokenizer( eos_token_id = None
lora_request).eos_token_id if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request) eos_token_id, lora_request)
...@@ -327,6 +417,8 @@ class LLMEngine: ...@@ -327,6 +417,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens # inject the eos token id into the sampling_params to support min_tokens
# processing # processing
sampling_params.eos_token_id = seq.eos_token_id sampling_params.eos_token_id = seq.eos_token_id
sampling_params.update_from_generation_config(
self.generation_config_fields)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
...@@ -366,235 +458,35 @@ class LLMEngine: ...@@ -366,235 +458,35 @@ class LLMEngine:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _check_beam_search_early_stopping(
    self,
    early_stopping: Union[bool, str],
    sampling_params: SamplingParams,
    best_running_seq: Sequence,
    current_worst_seq: Sequence,
) -> bool:
    """Decide whether beam search can stop for a sequence group.

    Args:
        early_stopping: ``True``, ``False`` or the string ``"never"``.
        sampling_params: Must have ``use_beam_search`` set; supplies
            ``length_penalty`` and ``max_tokens``.
        best_running_seq: The running sequence whose attainable score is
            estimated.
        current_worst_seq: The finished sequence whose score is the bar
            to beat.

    Returns:
        ``True`` immediately when ``early_stopping`` is ``True``. Otherwise
        ``True`` when the score of ``current_worst_seq`` is at least the
        highest score estimated for ``best_running_seq``.
    """
    assert sampling_params.use_beam_search
    penalty = sampling_params.length_penalty
    if early_stopping is True:
        return True

    worst_finished_score = current_worst_seq.get_beam_search_score(
        length_penalty=penalty,
        eos_token_id=current_worst_seq.eos_token_id)

    # Keyword arguments shared by every way of scoring the running sequence.
    score_kwargs = {
        "length_penalty": penalty,
        "eos_token_id": best_running_seq.eos_token_id,
    }
    if early_stopping is not False:
        assert early_stopping == "never"
        if penalty > 0.0:
            # A positive length penalty makes beam search prefer longer
            # sequences, so the attainable score is computed at the longest
            # length considered here.
            score_kwargs["seq_len"] = max(
                best_running_seq.get_prompt_len() +
                sampling_params.max_tokens,
                self.scheduler_config.max_model_len)
        # Otherwise shorter sequences are preferred and the current
        # sequence length is used (no explicit ``seq_len``).

    best_attainable_score = best_running_seq.get_beam_search_score(
        **score_kwargs)
    return worst_finished_score >= best_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                    outputs: SequenceGroupOutput) -> None:
    """Apply one step of sampler output to a single sequence group.

    The method works in four stages:

    1. Attach prompt logprobs to the group, if the output carries any.
    2. Append each sampled token to its parent sequence, forking the parent
       when it received more than one sample, and drop running parents that
       received none.
    3. Detokenize every resulting sequence and run ``self._check_stop`` on
       it.
    4. Update the group and the scheduler: without beam search every child
       is kept; with beam search only the best ``best_of`` finished and
       running candidates are kept and the rest are removed and freed.

    Args:
        seq_group: The sequence group to update in place.
        outputs: The output for this group; ``prompt_logprobs`` and
            ``samples`` are read from it.
    """
    # Process prompt logprobs
    prompt_logprobs = outputs.prompt_logprobs
    if prompt_logprobs is not None:
        self.detokenizer.decode_prompt_logprobs_inplace(
            seq_group, prompt_logprobs)
        seq_group.prompt_logprobs = prompt_logprobs

    # Process samples
    samples = outputs.samples
    parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
    existing_finished_seqs = seq_group.get_finished_seqs()
    # Group the samples by the id of the running parent they extend.
    parent_child_dict = {
        parent_seq.seq_id: []
        for parent_seq in parent_seqs
    }
    for sample in samples:
        parent_child_dict[sample.parent_seq_id].append(sample)
    # List of (child, parent)
    child_seqs: List[Tuple[Sequence, Sequence]] = []

    # Process the child samples for each parent sequence
    for parent in parent_seqs:
        child_samples: List[SequenceOutput] = parent_child_dict[
            parent.seq_id]
        if len(child_samples) == 0:
            # This parent sequence has no children samples. Remove
            # the parent sequence from the sequence group since it will
            # not be used in the future iterations.
            parent.status = SequenceStatus.FINISHED_ABORTED
            seq_group.remove(parent.seq_id)
            self.scheduler.free_seq(parent)
            continue
        # Fork the parent sequence if there are multiple child samples.
        for child_sample in child_samples[:-1]:
            new_child_seq_id = next(self.seq_counter)
            child = parent.fork(new_child_seq_id)
            child.append_token_id(child_sample.output_token,
                                  child_sample.logprobs)
            child_seqs.append((child, parent))
        # Continue the parent sequence for the last child sample.
        # We reuse the parent sequence here to reduce redundant memory
        # copies, especially when using non-beam search sampling methods.
        last_child_sample = child_samples[-1]
        parent.append_token_id(last_child_sample.output_token,
                               last_child_sample.logprobs)
        # A (parent, parent) pair marks a sequence that continues in place.
        child_seqs.append((parent, parent))

    # Decode the new token of every candidate and apply the stop check
    # before deciding which candidates to keep.
    for seq, _ in child_seqs:
        self.detokenizer.decode_sequence_inplace(seq,
                                                 seq_group.sampling_params)
        self._check_stop(seq, seq_group.sampling_params)

    # Non-beam search case
    if not seq_group.sampling_params.use_beam_search:
        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        # NOTE: we need to fork the new sequences before freeing the
        # old sequences.
        for seq, parent in child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)
        return

    # Beam search case
    # Select the child sequences to keep in the sequence group.
    selected_child_seqs = []
    unselected_child_seqs = []
    beam_width = seq_group.sampling_params.best_of
    length_penalty = seq_group.sampling_params.length_penalty

    # Select the newly finished sequences with the highest scores
    # to replace existing finished sequences.
    # Tuple of (seq, parent, is_new)
    existing_finished_seqs = [(seq, None, False)
                              for seq in existing_finished_seqs]
    new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                         if seq.is_finished()]
    all_finished_seqs = existing_finished_seqs + new_finished_seqs
    # Sort the finished sequences by their scores.
    all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
        length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                           reverse=True)
    for seq, parent, is_new in all_finished_seqs[:beam_width]:
        if is_new:
            # A newly generated child sequence finishes and has a high
            # score, so we will add it into the sequence group.
            selected_child_seqs.append((seq, parent))
    for seq, parent, is_new in all_finished_seqs[beam_width:]:
        if is_new:
            # A newly generated child sequence finishes but has a low
            # score, so we will not add it into the sequence group.
            # Additionally, if this sequence is a continuation of a
            # parent sequence, we will need remove the parent sequence
            # from the sequence group.
            unselected_child_seqs.append((seq, parent))
        else:
            # An existing finished sequence has a low score, so we will
            # remove it from the sequence group.
            seq_group.remove(seq.seq_id)

    # select the top beam_width sequences from the running
    # sequences for the next iteration to continue the beam
    # search.
    running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                          if not seq.is_finished()]
    # Sort the running sequences by their scores.
    running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
        length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                            reverse=True)

    # Check if we can stop the beam search.
    if len(running_child_seqs) == 0:
        # No running sequences, stop the beam search.
        stop_beam_search = True
    elif len(all_finished_seqs) < beam_width:
        # Not enough finished sequences, continue the beam search.
        stop_beam_search = False
    else:
        # Check the early stopping criteria
        best_running_seq = running_child_seqs[0][0]
        current_worst_seq = all_finished_seqs[beam_width - 1][0]
        stop_beam_search = self._check_beam_search_early_stopping(
            seq_group.sampling_params.early_stopping,
            seq_group.sampling_params, best_running_seq, current_worst_seq)

    if stop_beam_search:
        # Stop the beam search and remove all the running sequences from
        # the sequence group.
        unselected_child_seqs.extend(running_child_seqs)
    else:
        # Continue the beam search and select the top beam_width sequences
        # to continue the beam search.
        selected_child_seqs.extend(running_child_seqs[:beam_width])
        # The remaining running sequences will not be used in the next
        # iteration. Again, if these sequences are continuations of
        # parent sequences, we will need to remove the parent sequences
        # from the sequence group.
        unselected_child_seqs.extend(running_child_seqs[beam_width:])

    # For newly created child sequences, add them to the sequence group
    # and fork them in block manager if they are not finished.
    for seq, parent in selected_child_seqs:
        if seq is not parent:
            seq_group.add(seq)
            if not seq.is_finished():
                self.scheduler.fork_seq(parent, seq)

    # Free the finished and selected parent sequences' memory in block
    # manager. Keep them in the sequence group as candidate output.
    for seq, parent in selected_child_seqs:
        if seq is parent and seq.is_finished():
            self.scheduler.free_seq(seq)

    # Remove the unselected parent sequences from the sequence group and
    # free their memory in block manager.
    for seq, parent in unselected_child_seqs:
        if seq is parent:
            # Remove the parent sequence if it is not selected for next
            # iteration
            seq_group.remove(seq.seq_id)
            self.scheduler.free_seq(seq)
def _process_model_outputs( def _process_model_outputs(
self, output: SamplerOutput, self, output: List[SamplerOutput],
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
"""
now = time.time() now = time.time()
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): # Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
output_by_sequence_group):
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.update_num_computed_tokens(
seq_group.update_num_computed_tokens(token_chunk_size) scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
...@@ -606,13 +498,9 @@ class LLMEngine: ...@@ -606,13 +498,9 @@ class LLMEngine:
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:
...@@ -670,22 +558,42 @@ class LLMEngine: ...@@ -670,22 +558,42 @@ class LLMEngine:
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model( output = self.model_executor.execute_model(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs.blocks_to_swap_out, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_copy) blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
else: else:
output = [] output = []
return self._process_model_outputs(output, scheduler_outputs) request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
# Log stats.
if self.log_stats:
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output=output))
return request_outputs
def do_log_stats(self) -> None:
    """Forced log when no requests active."""
    if not self.log_stats:
        return
    # No scheduler outputs are available when the engine is idle, so the
    # stats are collected with scheduler_outputs=None.
    idle_stats = self._get_stats(scheduler_outputs=None)
    self.stat_logger.log(idle_stats)
def _get_stats(self, def _get_stats(
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: self,
"""Get Stats to be Logged to Prometheus.""" scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus.
Args:
scheduler_outputs: Optional, used to populate metrics related to
the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics
which are created by the workers.
"""
now = time.time() now = time.time()
# KV Cache Usage in %. # KV Cache Usage in %.
...@@ -712,7 +620,7 @@ class LLMEngine: ...@@ -712,7 +620,7 @@ class LLMEngine:
time_per_output_tokens = [] time_per_output_tokens = []
time_e2e_requests = [] time_e2e_requests = []
if scheduler_outputs is not None: if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run prompt_run = scheduler_outputs.num_prefill_groups > 0
# Number of Tokens. # Number of Tokens.
if prompt_run: if prompt_run:
...@@ -742,6 +650,14 @@ class LLMEngine: ...@@ -742,6 +650,14 @@ class LLMEngine:
time_to_first_tokens = time_last_iters if prompt_run else [] time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters time_per_output_tokens = [] if prompt_run else time_last_iters
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
return Stats( return Stats(
now=now, now=now,
num_running=num_running, num_running=num_running,
...@@ -754,58 +670,9 @@ class LLMEngine: ...@@ -754,58 +670,9 @@ class LLMEngine:
time_to_first_tokens=time_to_first_tokens, time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens, time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests, time_e2e_requests=time_e2e_requests,
spec_decode_metrics=spec_decode_metrics,
) )
def _check_stop(self, seq: Sequence,
                sampling_params: SamplingParams) -> None:
    """Mark ``seq`` as finished if it meets a stopping condition.

    Conditions are evaluated in this order: model length cap, max_tokens
    cap, (min_tokens guard), stop strings, stop token ids, EOS token. The
    first one that applies sets ``seq.status`` and ends the check.
    """
    # Length caps: the sequence outgrew the model context, or produced
    # exactly max_tokens output tokens.
    over_context = seq.get_len() > self.scheduler_config.max_model_len
    if over_context or seq.get_output_len() == sampling_params.max_tokens:
        seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
        return

    # Below min_tokens none of the stop string/token checks apply.
    if seq.get_output_len() < sampling_params.min_tokens:
        return

    # First configured stop string that the output text ends with, if any.
    matched_stop = next(
        (candidate for candidate in sampling_params.stop
         if seq.output_text.endswith(candidate)),
        None,
    )
    if matched_stop is not None:
        self._finalize_sequence(seq, sampling_params, matched_stop)
        seq.status = SequenceStatus.FINISHED_STOPPED
        seq.stop_reason = matched_stop
        return

    last_token_id = seq.get_last_token_id()
    if last_token_id in sampling_params.stop_token_ids:
        # The token's text form is needed so it can be trimmed from the
        # output text.
        tokenizer = self.get_tokenizer_for_seq(seq)
        stop_token_text = tokenizer.convert_ids_to_tokens(last_token_id)
        self._finalize_sequence(seq, sampling_params, stop_token_text)
        seq.status = SequenceStatus.FINISHED_STOPPED
        seq.stop_reason = last_token_id
        return

    # EOS token, unless the request asked to ignore it.
    eos_allowed = not sampling_params.ignore_eos
    if eos_allowed and seq.get_last_token_id() == seq.eos_token_id:
        seq.status = SequenceStatus.FINISHED_STOPPED
def _finalize_sequence(self, seq: Sequence,
sampling_params: SamplingParams,
stop_string: str) -> None:
if sampling_params.include_stop_str_in_output:
return
if stop_string and seq.output_text.endswith(stop_string):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_string)]
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request) return self.model_executor.add_lora(lora_request)
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
...@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, ...@@ -8,6 +8,9 @@ from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
logger = init_logger(__name__) logger = init_logger(__name__)
disable_created_metrics() disable_created_metrics()
...@@ -118,13 +121,21 @@ class Stats: ...@@ -118,13 +121,21 @@ class Stats:
time_per_output_tokens: List[float] time_per_output_tokens: List[float]
time_e2e_requests: List[float] time_e2e_requests: List[float]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
class SupportsMetricsInfo(Protocol):
    """Structural type for objects that expose a ``metrics_info()`` method."""

    def metrics_info(self) -> Dict[str, str]:
        """Return the object's metrics information as a str-to-str mapping."""
        ...
class StatLogger: class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" """StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally. # Metadata for logging locally.
self.last_local_log = time.monotonic() self.last_local_log = time.time()
self.local_interval = local_interval self.local_interval = local_interval
# Tracked stats over current local logging interval. # Tracked stats over current local logging interval.
...@@ -135,7 +146,7 @@ class StatLogger: ...@@ -135,7 +146,7 @@ class StatLogger:
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys())) self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info()) self.metrics.info_cache_config.info(obj.metrics_info())
...@@ -229,3 +240,19 @@ class StatLogger: ...@@ -229,3 +240,19 @@ class StatLogger:
self.num_prompt_tokens = [] self.num_prompt_tokens = []
self.num_generation_tokens = [] self.num_generation_tokens = []
self.last_local_log = stats.now self.last_local_log = stats.now
if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
return ("Speculative metrics: "
f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
f"System efficiency: {metrics.system_efficiency:.3f}, "
f"Number of speculative tokens: {metrics.num_spec_tokens}, "
f"Number of accepted tokens: {metrics.accepted_tokens}, "
f"Number of draft tokens tokens: {metrics.draft_tokens}, "
f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
from abc import ABC, abstractmethod
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
    """Interface for logic that processes new token ids in sequence groups,
    managing detokenization, stop checking, and freeing/forking sequences with
    the scheduler.

    This is highly coupled with the LLMEngine and should be seen as an
    extension of it. Keeping it separate simplifies the LLMEngine class and
    allows distinct implementations for single-step decoding (which supports
    beam search sequence forking) and multi-step decoding (which does not
    support beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

        A multi-step output processor is returned when
        ``scheduler_config.num_lookahead_slots`` is non-zero; otherwise a
        single-step output processor is returned.
        """
        if scheduler_config.num_lookahead_slots != 0:
            # Imported lazily to avoid an import cycle.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer=detokenizer,
                scheduler=scheduler,
                seq_counter=seq_counter,
                get_tokenizer_for_seq=get_tokenizer_for_seq,
                stop_checker=stop_checker,
            )

        # Imported lazily to avoid an import cycle.
        from vllm.engine.output_processor.single_step import (
            SingleStepOutputProcessor)
        return SingleStepOutputProcessor(
            scheduler_config=scheduler_config,
            detokenizer=detokenizer,
            scheduler=scheduler,
            seq_counter=seq_counter,
            stop_checker=stop_checker,
        )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
        """
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor for "multi-step decoding", where vLLM's
    worker may generate multiple tokens per invocation.

    It covers detokenization and stopping conditions. Multi-step decoding is
    currently mutually exclusive with advanced sampling techniques like beam
    search, which is why this logic lives apart from the single-step output
    processor.

    Responsibilities: appending all new token ids to their sequence,
    detokenizing them, truncating new output tokens after an eos token, and
    handling batches where the number of new output tokens differs per
    sequence.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        Only sequence groups of size 1 are supported, but each sequence may
        receive more than one new token.

        Stop condition checking and detokenization are applied per token,
        finished sequences are freed, and tokens emitted after the EOS token
        are discarded.
        """
        running_seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        assert running_seqs, "expected running sequences"
        assert len(running_seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = running_seqs[0]

        # With a single sequence per group, each step contributes exactly
        # its first sample.
        first_samples = [step_output.samples[0] for step_output in outputs]

        # -1 means the output token is not valid (eg. due to spec decode
        # rejecting tokens).
        accepted_samples = [
            candidate for candidate in first_samples
            if candidate.output_token != -1
        ]
        assert accepted_samples

        self._process_seq_outputs(seq, accepted_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the sampled tokens to ``seq`` one at a time, checking stop
        conditions after each, and free the sequence once it finishes."""
        token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        overflow = (seq.get_output_len() +
                    len(token_ids)) - sampling_params.max_tokens
        if overflow > 0:
            token_ids = token_ids[:-overflow]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # A plain scan is used instead of list.index, whose exception on
            # a miss would be raised on the common path.
            for position, token_id in enumerate(token_ids):
                if token_id == eos_token_id:
                    token_ids = token_ids[:position + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one
        # new token.
        for token_id in token_ids:
            seq.append_token_id(
                token_id=token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={token_id: Logprob(0.0)},
            )

            new_char_count = (self.detokenizer.decode_sequence_inplace(
                seq, sampling_params) if sampling_params.detokenize else 0)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        stop_checker: StopChecker,
    ):
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0])

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to ``seq_group``.

        Stores prompt logprobs, appends the sampled token to each running
        sequence (forking the parent when it has several samples), runs
        detokenization and stop checking, and then updates the sequence group
        and the scheduler. For beam search the finished and running candidates
        are additionally ranked and pruned to the beam width.
        """
        # Process prompt logprobs. Detokenizing them is optional, but the
        # logprobs themselves are stored whenever the model returned them, so
        # that a request with detokenize disabled does not lose them.
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            if seq_group.sampling_params.detokenize and self.detokenizer:
                self.detokenizer.decode_prompt_logprobs_inplace(
                    seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict: Dict[int, List[SequenceOutput]] = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if not child_samples:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id: int = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        # Detokenize the new token of every candidate and check whether the
        # candidate is now finished.
        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
                                                  seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if not running_child_seqs:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search can stop for this sequence group.

        Args:
            early_stopping: ``True``, ``False`` or the string ``"never"``.
            sampling_params: Beam-search sampling parameters; supplies
                ``length_penalty`` and ``max_tokens``.
            best_running_seq: Highest-scoring sequence that is still running.
            current_worst_seq: Lowest-scoring sequence among the finished
                ones that currently fit in the beam.

        Returns:
            ``True`` unconditionally when ``early_stopping`` is ``True``.
            Otherwise ``True`` when the worst kept finished score is at least
            the highest score the best running sequence is assumed able to
            reach: its current score for ``False``, or for ``"never"`` with a
            positive length penalty its score evaluated at the longest
            possible sequence length.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                # NOTE(review): this takes the larger of the two length
                # bounds - confirm that the tighter bound was not intended.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence],
                                                 PreTrainedTokenizer]):
        self.max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> None:
        """Stop the finished sequences.

        Sets ``seq.status`` (and ``seq.stop_reason`` for stop tokens and stop
        strings) when a stopping condition is met. EOS, stop-token and
        stop-string checks only run once ``min_tokens`` output tokens exist;
        the ``max_model_len`` and ``max_tokens`` caps are always enforced.

        Args:
            seq: Sequence whose newest token is being evaluated.
            new_char_count: Number of chars added to the sequence's output
                text for the newly generated token.
            sampling_params: Sampling parameters carrying the stop criteria.
        """
        # The EOS / stop token / stop string checks are skipped until the
        # minimum number of tokens has been generated. The length caps below
        # are deliberately outside this guard so that a sequence can never
        # grow past max_model_len while waiting to reach min_tokens.
        if seq.get_output_len() >= sampling_params.min_tokens:
            # Check if the sequence has generated the EOS token.
            if ((not sampling_params.ignore_eos)
                    and seq.get_last_token_id() == seq.eos_token_id):
                seq.status = SequenceStatus.FINISHED_STOPPED
                return

            # Check if a stop token was encountered.
            # This assumes a single token produced per step.
            last_token_id = seq.get_last_token_id()
            if last_token_id in sampling_params.stop_token_ids:
                if new_char_count and (
                        not sampling_params.include_stop_str_in_output):
                    # Remove the text contributed by the stop token.
                    seq.output_text = seq.output_text[:-new_char_count]
                seq.status = SequenceStatus.FINISHED_STOPPED
                seq.stop_reason = last_token_id
                return

            # Check if any stop strings are matched.
            stop_str = self._check_stop_strings(seq, new_char_count,
                                                sampling_params)
            if stop_str is not None:
                seq.status = SequenceStatus.FINISHED_STOPPED
                seq.stop_reason = stop_str
                return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Only the tail of the output text that could contain a new match is
        searched: the ``new_char_count`` new characters plus one stop-string
        length before them.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None
from typing import List
from vllm.sequence import SamplerOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
                                    num_seq_groups: int):
    """Regroup sampler outputs from [step][sequence group] order into
    [sequence group][step] order.

    The result always holds ``num_seq_groups`` lists; the i-th list collects,
    in step order, the i-th entry of every step.
    """
    per_group: List[List[SamplerOutput]] = [
        [] for _ in range(num_seq_groups)
    ]
    for step_outputs in sampler_outputs:
        for group_index, group_output in enumerate(step_outputs):
            per_group[group_index].append(group_output)
    return per_group
...@@ -3,47 +3,26 @@ from typing import List, Optional, Tuple ...@@ -3,47 +3,26 @@ from typing import List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
import ray import ray
class RayWorkerVllm: class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, init_cached_hf_modules=False) -> None: def __init__(self, *args, **kwargs) -> None:
if init_cached_hf_modules: super().__init__(*args, **kwargs)
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
# Since the compiled DAG runs a main execution # Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device. # in a different thread that calls cuda.set_device.
# The flag indicates if set_device is called on # The flag indicates if set_device is called on
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def __getattr__(self, name):
return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs):
try:
executor = getattr(self, method)
return executor(*args, **kwargs)
except Exception as e:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -52,9 +31,6 @@ try: ...@@ -52,9 +31,6 @@ try:
gpu_ids = ray.get_gpu_ids() gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids return node_id, gpu_ids
def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)
def execute_model_compiled_dag_remote(self, ignored): def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled.""" """Used only when compiled DAG is enabled."""
import torch import torch
...@@ -70,8 +46,8 @@ except ImportError as e: ...@@ -70,8 +46,8 @@ except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray`.") "`pip install ray`.")
ray = None ray = None # type: ignore
RayWorkerVllm = None RayWorkerWrapper = None # type: ignore
def initialize_ray_cluster( def initialize_ray_cluster(
......
...@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: ...@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case # Streaming case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment