Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
......@@ -192,6 +192,11 @@ class BlockAllocator(ABC):
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache."""
pass
class NoFreeBlocksError(ValueError):
pass
......@@ -297,6 +302,11 @@ class DeviceAwareBlockAllocator(ABC):
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache."""
pass
@abstractmethod
def find_cached_blocks_prefix(
self,
......
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
......@@ -136,16 +136,18 @@ class NaiveBlockAllocator(BlockAllocator):
self._refcounter.incr(block_id)
return block_id
def _free_block_id(self, block: Block) -> None:
block_id = block.block_id
def _free_block_id(self, block: Union[Block, BlockId]) -> None:
if isinstance(block, Block):
block_id = block.block_id
block.block_id = None
else:
block_id = block
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
block.block_id = None
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
self._free_block_id(block)
......@@ -154,6 +156,9 @@ class NaiveBlockAllocator(BlockAllocator):
if not keep_block_object:
self._block_pool.free_block(block)
def free_block_id(self, block_id: BlockId) -> None:
self._free_block_id(block_id)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
......@@ -325,6 +330,10 @@ class NaiveBlockAllocator(BlockAllocator):
def get_prefix_cache_hit_rate(self) -> float:
return -1
def reset_prefix_cache(self) -> bool:
"""No prefix cache for naive block allocator."""
return True
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
# Not applicable for naive block allocator.
return []
......
......@@ -12,6 +12,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device,
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.logger import init_logger
from vllm.sequence import Sequence
PrefixHash = int
......@@ -21,6 +22,8 @@ PrefixHash = int
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
logger = init_logger(__name__)
class BlockTracker:
"""Used to track the status of a block inside the prefix caching allocator
......@@ -105,7 +108,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# Evictor used to maintain how we want to handle those computed blocks
# if we find memory pressure is high.
self.evictor: Evictor = make_evictor(eviction_policy)
self.eviction_policy = eviction_policy
self.evictor: Evictor = make_evictor(self.eviction_policy)
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
......@@ -428,6 +432,44 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def get_prefix_cache_hit_rate(self) -> float:
return self.metric_data.get_hit_rate()
def reset_prefix_cache(self) -> bool:
    """Reset prefix cache. This function may be used in RLHF
    flows to invalidate prefix caching after the weights are updated,
    or used for resetting prefix caching status for benchmarking.

    The reset is refused (and nothing is modified) while any block is
    still in use, i.e. while the number of free blocks is smaller than
    the total number of blocks.

    Returns:
        bool: True if the prefix cache is successfully reset,
        False otherwise.
    """
    # Blocks that are neither free nor evictable still belong to live
    # sequences; dropping their cache state would be unsafe, so bail out.
    num_used_blocks = (self.get_num_total_blocks() -
                       self.get_num_free_blocks())
    if num_used_blocks > 0:
        logger.warning(
            "Failed to reset prefix cache because some "
            "blocks (%d) are not freed yet", num_used_blocks)
        return False

    # Free all blocks in the evictor.
    # Each evicted block id is handed back to the hashless allocator until
    # the helper reports that nothing is left (returns None).
    while (block_id :=
           self._maybe_allocate_evicted_block_id()) is not None:
        self._hashless_allocator.free_block_id(block_id)

    # Should not have any cached blocks because all blocks are evicted.
    assert not self._cached_blocks

    # Reset the evictor.
    # A fresh evictor is built with the policy remembered at construction.
    self.evictor = make_evictor(self.eviction_policy)

    # Reset the block tracker.
    # Only values are replaced, so iterating the dict while assigning is safe.
    for block_id in self._block_tracker:
        self._block_tracker[block_id] = BlockTracker()

    # Reset the metrics.
    self.metric_data = CacheMetricData()

    logger.info("Successfully reset prefix cache")
    return True
def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
return block.content_hash in self._cached_blocks
......
......@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
device=Device.GPU)
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
if (self.num_total_gpu_blocks - num_required_blocks
< self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
......@@ -455,6 +455,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_allocator.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool:
return self.block_allocator.reset_prefix_cache()
def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
......
......@@ -122,6 +122,11 @@ class BlockSpaceManager(ABC):
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache for all devices."""
pass
@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
pass
......@@ -90,5 +90,8 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1
def reset_prefix_cache(self) -> bool:
return True
def get_num_cached_tokens(self, seq: Sequence) -> int:
return 0
......@@ -504,6 +504,9 @@ class Scheduler:
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool:
return self.block_manager.reset_prefix_cache()
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
......@@ -985,8 +988,8 @@ class Scheduler:
waiting_queue.popleft()
continue
if (budget.num_batched_tokens >=
self.scheduler_config.max_num_batched_tokens):
if (budget.num_batched_tokens
>= self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
......@@ -1093,8 +1096,8 @@ class Scheduler:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
......@@ -1186,8 +1189,8 @@ class Scheduler:
curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
......@@ -1355,8 +1358,8 @@ class Scheduler:
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()):
if (token_chunk_size + num_computed_tokens
< seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by
......@@ -1579,6 +1582,7 @@ class Scheduler:
seq.status = SequenceStatus.WAITING
self.free_seq(seq)
seq.reset_state_for_recompute()
self._free_seq_group_cross_attn_blocks(seq_group)
def _preempt_by_swap(
self,
......@@ -1621,10 +1625,9 @@ class Scheduler:
if self.scheduler_config.delay_factor > 0 and self.waiting:
earliest_arrival_time = min(
[e.metrics.arrival_time for e in self.waiting])
passed_delay = (
(now - earliest_arrival_time) >
(self.scheduler_config.delay_factor * self.last_prompt_latency)
or not self.running)
passed_delay = ((now - earliest_arrival_time)
> (self.scheduler_config.delay_factor *
self.last_prompt_latency) or not self.running)
else:
passed_delay = True
return passed_delay
......
# cumem-based pytorch pluggable allocator to implement sleep mode.
# other approaches tried but failed:
# - cuda-python package binding
# - custom libcuda driver ctypes wrapper
# both of them failed because of cuda context mismatch.
# not sure why, they are created from a different context.
# the only successful approach is to call cuda driver API in C.
import dataclasses
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from vllm.utils import is_pin_memory_available
def find_loaded_library(lib_name) -> Optional[str]:
    """
    Return the on-disk path of a shared library loaded in this process.

    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which includes the
    shared libraries loaded by the process. We can use this file to find the path of
    a loaded library.

    Only the path column of file-backed mappings is searched, so anonymous
    or special mappings (for example `[heap]` or `[vdso]`) and the
    address/permission columns can never produce a match.

    :param lib_name: Leading part of the library file name to look for,
        e.g. ``"libcudart"`` or ``"cumem_allocator"``.
    :return: The path of the first mapping whose path contains `lib_name`,
        or None if the library is not loaded in the current process.
    :raises AssertionError: If the file name of the matched mapping (the
        part before ``.so``) does not start with `lib_name`.
    """ # noqa
    path = None
    with open("/proc/self/maps") as f:
        for line in f:
            # The first "/" on a maps line starts the pathname column;
            # lines without one are not backed by a file and are skipped.
            start = line.find("/")
            if start == -1:
                continue
            candidate = line[start:].strip()
            if lib_name in candidate:
                path = candidate
                break
    if path is None:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    filename = path.split("/")[-1]
    assert filename.rpartition(".so")[0].startswith(lib_name), \
        f"Unexpected filename: {filename} for library {lib_name}"
    return path
# Flag checked by CuMemAllocator.get_instance(); it only becomes True when
# the C extension and the CUDA runtime wrapper were both set up successfully.
cumem_available = False
try:
    from vllm.cumem_allocator import (init_module, python_create_and_map,
                                      python_unmap_and_release)
    from vllm.distributed.device_communicators.cuda_wrapper import (
        CudaRTLibrary)
    # Path of the already-loaded extension library (None if it cannot be
    # found); it is handed to CUDAPluggableAllocator further below.
    lib_name = find_loaded_library("cumem_allocator")
    libcudart = CudaRTLibrary()
    cumem_available = True
except ModuleNotFoundError:
    # rocm platform does not support cumem allocator
    # Define every name as None so that importing this module still works;
    # cumem_available stays False in this case.
    init_module = None
    python_create_and_map = None
    python_unmap_and_release = None
    CudaRTLibrary = None
    lib_name = None
    libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
    """Bookkeeping record kept for one allocation made in the memory pool."""
    # (py_device, py_alignedSize, py_d_mem, py_p_memHandle) as received from
    # the allocator's malloc callback.
    handle: HandleType
    # Tag that was current when the allocation was made; `sleep` uses it to
    # decide whether the contents are offloaded to CPU or discarded.
    tag: str
    # CPU copy of the contents while the allocator sleeps; None when no
    # backup is held.
    cpu_backup_tensor: Optional[torch.Tensor] = None
def create_and_map(allocation_handle: HandleType) -> None:
    """Forward the unpacked handle fields to `python_create_and_map`.

    Called by `CuMemAllocator.wake_up` for every tracked allocation before
    any offloaded data is copied back.
    """
    python_create_and_map(*allocation_handle)
def unmap_and_release(allocation_handle: HandleType) -> None:
    """Forward the unpacked handle fields to `python_unmap_and_release`.

    Called by `CuMemAllocator.sleep` for every tracked allocation, after the
    contents were backed up when their tag asked for it.
    """
    python_unmap_and_release(*allocation_handle)
def get_pluggable_allocator(
    python_malloc_fn: Callable[[int], int],
    python_free_func: Callable[[int, int], None]
) -> torch.cuda.memory.CUDAPluggableAllocator:
    """Build a PyTorch pluggable allocator backed by the cumem extension.

    :param python_malloc_fn: Python callback registered with the extension
        for allocations.
    :param python_free_func: Python callback registered with the extension
        for frees.
    :return: A `CUDAPluggableAllocator` that resolves the `my_malloc` and
        `my_free` symbols from the library found at `lib_name`.
    """
    # The callbacks must be registered with the C extension before the
    # allocator object is created.
    init_module(python_malloc_fn, python_free_func)
    return torch.cuda.memory.CUDAPluggableAllocator(lib_name, 'my_malloc',
                                                    'my_free')
@contextmanager
def use_memory_pool_with_allocator(
        python_malloc_fn: Callable[[int], int],
        python_free_func: Callable[[int, int], None]) -> None:
    """Context manager that routes CUDA allocations through a new pool.

    A pluggable allocator is created from the two callbacks, wrapped in a
    `torch.cuda.memory.MemPool`, and that pool is made active for the
    duration of the `with` block. The pool object is yielded to the caller.
    """
    pluggable_allocator = get_pluggable_allocator(python_malloc_fn,
                                                  python_free_func)
    pool = torch.cuda.memory.MemPool(pluggable_allocator._allocator)
    with torch.cuda.memory.use_mem_pool(pool):
        yield pool
class CuMemAllocator:
    """
    A singleton class that manages a memory pool for CUDA tensors.
    The memory in this pool can be offloaded or discarded when the
    allocator sleeps.

    Inside the `use_memory_pool(tag)` context, all tensors created will
    be allocated in the memory pool, and has the same tag as the
    tag passed to the context.

    When we call `sleep`, all tensors with the specified tag will be
    offloaded to CPU memory, and the rest of the tensors will be discarded.
    When we call `wake_up`, all tensors that are previously offloaded
    will be loaded back to GPU memory, and the rest of the tensors will
    have empty memory.

    Why it needs to be a singleton?
    When allocated tensors are garbage collected, PyTorch will call
    the free callback, which will call the `python_free_callback` method.
    The C-extension uses a global variable to store the function of an
    instance of this class. If we create multiple instances of this class,
    the global variable will be overwritten and the free callback will
    not work as expected.
    """
    # Lazily created by `get_instance`; None until the first call.
    instance: Optional["CuMemAllocator"] = None
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        """
        CuMemAllocator is a singleton class.
        We cannot call the constructor directly.
        Call this method to get the instance.
        """
        assert cumem_available, "cumem allocator is not available"
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

    def __init__(self):
        # Maps the device pointer of each live allocation to its record.
        self.pointer_to_data: Dict[int, AllocationData] = {}
        # Tag attached to allocations made from now on.
        self.current_tag: str = CuMemAllocator.default_tag

    def python_malloc_callback(self, allocation_handle: HandleType) -> None:
        """
        Internal method to store the allocation data
        when memory is allocated in the memory pool."""
        # Index 2 of the handle is the device pointer (py_d_mem).
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
            allocation_handle, self.current_tag)

    def python_free_callback(self, ptr: int) -> HandleType:
        """
        Internal method to look up the allocation data
        when memory is freed in the memory pool.

        :param ptr: Device pointer of the allocation being freed.
        :return: The handle that was recorded for `ptr`.
        :raises KeyError: If `ptr` is not a tracked allocation.
        """
        data = self.pointer_to_data.pop(ptr)
        # Drop any CPU backup so it does not outlive the allocation.
        data.cpu_backup_tensor = None
        return data.handle

    def sleep(
            self,
            offload_tags: Optional[Union[Tuple[str, ...],
                                         str]] = None) -> None:
        """
        Put the allocator in sleep mode.
        All data in the memory allocation with the specified tag will be
        offloaded to CPU memory, and others will be discarded.

        :param offload_tags: The tags of the memory allocation that will be
            offloaded. The rest of the memory allocation will be discarded.
        """
        if offload_tags is None:
            # by default, allocated tensors are offloaded
            # when the allocator sleeps
            offload_tags = (CuMemAllocator.default_tag, )
        elif isinstance(offload_tags, str):
            offload_tags = (offload_tags, )

        assert isinstance(offload_tags, tuple)

        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            if data.tag in offload_tags:
                # Index 1 of the handle is the aligned allocation size.
                size_in_bytes = handle[1]
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
                    device='cpu',
                    pin_memory=is_pin_memory_available())
                cpu_ptr = cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
                data.cpu_backup_tensor = cpu_backup_tensor
            # Every allocation is unmapped, whether or not it was backed up.
            unmap_and_release(handle)

    def wake_up(self):
        """
        Wake up the allocator from sleep mode.
        All data that is previously offloaded will be loaded back to GPU
        memory, and the rest of the data will have empty memory."""
        for ptr, data in self.pointer_to_data.items():
            create_and_map(data.handle)
            cpu_backup_tensor = data.cpu_backup_tensor
            if cpu_backup_tensor is not None:
                size_in_bytes = (cpu_backup_tensor.numel() *
                                 cpu_backup_tensor.element_size())
                cpu_ptr = cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
                # The backup is no longer needed once copied back.
                data.cpu_backup_tensor = None

    @contextmanager
    def use_memory_pool(self, tag: Optional[str] = None):
        """
        A context manager to use the memory pool.
        All memory allocation created inside the context will be allocated
        in the memory pool, and has the specified tag.

        The previously active tag is restored on exit, including when the
        body of the `with` block (or setting up the pool) raises.

        :param tag: The tag of the memory allocation. If None, the default tag
            will be used.
        """
        if tag is None:
            tag = CuMemAllocator.default_tag

        assert isinstance(tag, str)

        old_tag = self.current_tag
        self.current_tag = tag
        try:
            with use_memory_pool_with_allocator(self.python_malloc_callback,
                                                self.python_free_callback):
                yield
                # PyTorch's bug, calling torch.cuda.empty_cache() will error
                # when using pluggable allocator, see
                # https://github.com/pytorch/pytorch/issues/145168 .
                # if we have some memory allocated and then freed,
                # the memory will not be released.
                # right now it is fine, because we only use this allocator
                # during weight loading and kv cache creation, where we only
                # allocate memory.
                # TODO: we need to find a way to release the memory,
                # i.e. calling torch.cuda.empty_cache()
        finally:
            # Without this, an exception would leave later allocations
            # tagged with `tag` instead of the previous tag.
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        """
        Get the total number of bytes allocated in the memory pool.
        """
        # Index 1 of each handle is the aligned size of that allocation.
        return sum(data.handle[1] for data in self.pointer_to_data.values())
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
......@@ -11,6 +10,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream
logger = init_logger(__name__)
......@@ -51,7 +51,6 @@ class PyNcclCommunicator:
if self.world_size == 1:
self.available = False
self.disabled = True
self.stream = None
return
try:
self.nccl = NCCLLibrary(library_path)
......@@ -60,7 +59,6 @@ class PyNcclCommunicator:
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
self.stream = None
return
self.available = True
......@@ -98,12 +96,12 @@ class PyNcclCommunicator:
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
self.stream = torch.cuda.Stream()
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
stream.synchronize()
del data
def all_reduce(self,
......@@ -122,7 +120,7 @@ class PyNcclCommunicator:
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = self.stream
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
......@@ -144,7 +142,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
......@@ -165,7 +163,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
......@@ -180,7 +178,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
......@@ -192,7 +190,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
......@@ -204,7 +202,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
......@@ -215,27 +213,3 @@ class PyNcclCommunicator:
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
@contextmanager
def change_state(self,
enable: Optional[bool] = None,
stream: Optional[torch.cuda.Stream] = None):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available
if stream is None:
stream = self.stream
old_disable = self.disabled
old_stream = self.stream
self.stream = stream
self.disabled = not enable
yield
self.disabled = old_disable
self.stream = old_stream
......@@ -247,7 +247,8 @@ class MessageQueue:
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle(),
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
......@@ -351,8 +352,8 @@ class MessageQueue:
sched_yield()
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
......@@ -409,8 +410,8 @@ class MessageQueue:
sched_yield()
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
......
......@@ -22,7 +22,7 @@ NOTE: If you want to not only transfer KV caches, but adjust the model execution
## Disaggregated prefilling
The example usage is in [this file](../../../examples/disaggregated_prefill.sh).
The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh).
Here is the diagram of how we run disaggregated prefilling.
......
from typing import TYPE_CHECKING
import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type
from .base import KVConnectorBase
......@@ -7,14 +8,41 @@ if TYPE_CHECKING:
class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
@staticmethod
def create_connector(rank: int, local_rank: int,
@classmethod
def register_connector(cls, name: str, module_path: str,
class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> Type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if config.kv_transfer_config.kv_connector in supported_kv_connector:
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
return connector_cls(rank, local_rank, config)
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"PyNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
......@@ -35,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
......@@ -161,7 +162,7 @@ class SimpleConnector(KVConnectorBase):
end_layer = model_executable.model.end_layer
model_config = model_executable.model.config
num_heads = model_config.num_key_value_heads
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
......
......@@ -39,7 +39,6 @@ import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op
if TYPE_CHECKING:
......@@ -194,6 +193,7 @@ class GroupCoordinator:
assert self.cpu_group is not None
assert self.device_group is not None
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
......@@ -305,15 +305,7 @@ class GroupCoordinator:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
stream=torch.cuda.current_stream())
with maybe_pynccl_context:
yield graph_capture_context
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
......@@ -365,10 +357,7 @@ class GroupCoordinator:
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
# TODO: pynccl should not use `stream=`
# it can just always use the current stream.
out = pynccl_comm.all_reduce(input_,
stream=torch.cuda.current_stream())
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
......@@ -873,12 +862,14 @@ def init_model_parallel_group(
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_pynccl=current_platform.is_cuda_alike(),
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
......@@ -920,7 +911,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
@contextmanager
def graph_capture():
def graph_capture(device: torch.device):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
......@@ -934,8 +925,9 @@ def graph_capture():
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
with get_tp_group().graph_capture() as context, get_pp_group(
).graph_capture(context):
context = GraphCaptureContext(torch.cuda.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
context):
yield context
......@@ -1022,8 +1014,8 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
......@@ -1077,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
return
if all([
vllm_config.kv_transfer_config.need_kv_parallel_group,
_KV_TRANSFER is None
vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
......@@ -1188,8 +1180,14 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()
try:
torch._C._host_emptyCache()
except AttributeError:
logger.warning(
"torch._C._host_emptyCache() only available in Pytorch >=2.5")
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
......
......@@ -18,7 +18,6 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
......@@ -99,10 +98,8 @@ class EngineArgs:
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
......@@ -201,6 +198,10 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
calculate_kv_scales: Optional[bool] = None
def __post_init__(self):
if not self.tokenizer:
......@@ -242,7 +243,7 @@ class EngineArgs:
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'multiple tasks. When the model only supports one task, ``"auto"`` '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument(
......@@ -254,7 +255,7 @@ class EngineArgs:
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer')
help='Skip initialization of tokenizer and detokenizer.')
parser.add_argument(
'--revision',
type=nullable_str,
......@@ -352,18 +353,7 @@ class EngineArgs:
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version '
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
......@@ -392,7 +382,7 @@ class EngineArgs:
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
choices=['ray', 'mp', 'uni', 'external_launcher'],
default=EngineArgs.distributed_executor_backend,
help='Backend to use for distributed model '
'workers, either "ray" or "mp" (multiprocessing). If the product '
......@@ -400,12 +390,8 @@ class EngineArgs:
'or equal to the number of GPUs available, "mp" will be used to '
'keep processing on a single host. Otherwise, this will default '
'to "ray" if Ray is installed and fail otherwise. Note that tpu '
'and hpu only support Ray for distributed inference.')
'only supports Ray for distributed inference.')
parser.add_argument(
'--worker-use-ray',
action='store_true',
help='Deprecated, use --distributed-executor-backend=ray.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
......@@ -434,7 +420,7 @@ class EngineArgs:
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to max-model-len. On CUDA devices, '
'set to ``--max-model-len``. On CUDA devices, '
'only block sizes up to 32 are supported. '
'On HPU devices, block size defaults to 128.')
......@@ -443,12 +429,12 @@ class EngineArgs:
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Enables automatic prefix caching. "
"Use --no-enable-prefix-caching to disable explicitly.",
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size')
'capping to sliding window size.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=True,
......@@ -542,7 +528,7 @@ class EngineArgs:
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"rope_type":"dynamic","factor":2.0}')
'For example, ``{"rope_type":"dynamic","factor":2.0}``')
parser.add_argument('--rope-theta',
default=None,
type=float,
......@@ -611,7 +597,7 @@ class EngineArgs:
default=None,
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
'e.g., image processor. For example: ``{"num_crops": 4}``.'))
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
......@@ -879,7 +865,7 @@ class EngineArgs:
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument. Noted that this name(s) "
"same as the ``--model`` argument. Noted that this name(s) "
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics "
"tag will take the first one.")
......@@ -899,7 +885,7 @@ class EngineArgs:
default=None,
help="Valid choices are " +
",".join(ALLOWED_DETAILED_TRACE_MODULES) +
". It makes sense to set this only if --otlp-traces-endpoint is"
". It makes sense to set this only if ``--otlp-traces-endpoint`` is"
" set. If set, it will collect detailed traces for the specified "
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")
......@@ -926,13 +912,13 @@ class EngineArgs:
type=json.loads,
default=None,
help="Override or set neuron device configuration. "
"e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
"e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
parser.add_argument(
'--override-pooler-config',
type=PoolerConfig.from_json,
default=None,
help="Override or set the pooling method for pooling models. "
"e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
"e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
parser.add_argument('--compilation-config',
'-O',
......@@ -962,16 +948,43 @@ class EngineArgs:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
"Defaults to None, no generation config is loaded, vLLM defaults "
"will be used. If set to 'auto', the generation config will be "
"loaded from model path. If set to a folder path, the generation "
"config will be loaded from the specified folder path. If "
"`max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")
parser.add_argument(
"--override-generation-config",
type=json.loads,
default=None,
help="Overrides or sets generation config in JSON format. "
"e.g. ``{\"temperature\": 0.5}``. If used with "
"--generation-config=auto, the override parameters will be merged "
"with the default config from the model. If generation-config is "
"None, only the override parameters are used.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
default=False,
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
return parser
......@@ -1002,7 +1015,6 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
......@@ -1017,7 +1029,10 @@ class EngineArgs:
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
)
def create_load_config(self) -> LoadConfig:
return LoadConfig(
......@@ -1077,11 +1092,11 @@ class EngineArgs:
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
worker_use_ray=self.worker_use_ray,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
......@@ -1111,6 +1126,7 @@ class EngineArgs:
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
from vllm.platforms import current_platform
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
......@@ -1166,7 +1182,7 @@ class EngineArgs:
num_speculative_heads=self.num_speculative_heads
)
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo becomes valid
if self.num_scheduler_steps > 1:
if speculative_config is not None:
......@@ -1175,6 +1191,12 @@ class EngineArgs:
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1")
from vllm.platforms import current_platform
if current_platform.is_cpu():
logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
"currently not supported for CPUs and has been "
"disabled.")
self.num_scheduler_steps = 1
# make sure num_lookahead_slots is set the higher value depending on
# if we are using speculative decoding or multi-step
......@@ -1285,11 +1307,22 @@ class EngineArgs:
self.enable_chunked_prefill = True
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
# Use different default values for different hardware.
from vllm.platforms import current_platform
device_name = current_platform.get_device_name().lower()
if "h100" in device_name or "h200" in device_name:
# For H100 and H200, we use larger default values.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192,
}
else:
# TODO(woosuk): Tune the default values for other hardware.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
......
......@@ -18,9 +18,7 @@ from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
......@@ -620,69 +618,9 @@ class AsyncLLMEngine(EngineClient):
rt.new_requests_event.set()
@classmethod
def _get_executor_cls(
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
executor_class = RayHPUExecutorAsync
else:
from vllm.executor.hpu_executor import HPUExecutorAsync
executor_class = HPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
return executor_class
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config)
@classmethod
def from_engine_args(
......@@ -700,9 +638,6 @@ class AsyncLLMEngine(EngineClient):
executor_class = cls._get_executor_cls(engine_config)
if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)
# Create the async LLM engine.
engine = cls(
vllm_config=engine_config,
......@@ -1242,20 +1177,16 @@ class AsyncLLMEngine(EngineClient):
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
self.engine.start_profile()
async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
self.engine.stop_profile()
async def reset_prefix_cache(self) -> None:
    """Reset the prefix cache of the wrapped engine.

    Forwards to ``self.engine.reset_prefix_cache()``. Whatever that call
    returns is discarded, so awaiting this coroutine always yields ``None``.
    """
    wrapped_engine = self.engine
    wrapped_engine.reset_prefix_cache()
async def add_lora(self, lora_request: LoRARequest) -> None:
    """Hand ``lora_request`` to the wrapped engine's ``add_lora``.

    The result of the underlying call is not propagated; the coroutine
    resolves to ``None``.
    """
    wrapped_engine = self.engine
    wrapped_engine.add_lora(lora_request)
# TODO(v1): Remove this class proxy when V1 goes default.
......
......@@ -29,8 +29,6 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
......@@ -233,7 +231,7 @@ class LLMEngine:
)
logger.info(
"Initializing an LLM engine (v%s) with config: %s, "
"Initializing a V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
VLLM_VERSION,
vllm_config,
......@@ -445,64 +443,31 @@ class LLMEngine:
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
elif engine_config.parallel_config.world_size > 1:
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutor
executor_class = RayTPUExecutor
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutor
executor_class = RayHPUExecutor
else:
from vllm.executor.hpu_executor import HPUExecutor
executor_class = HPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
return executor_class
@classmethod
......@@ -727,7 +692,9 @@ class LLMEngine:
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
lora_request: The LoRA request to add.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: The prompt adapter request to add.
priority: The priority of the request.
Only applicable with priority scheduling.
......@@ -950,6 +917,14 @@ class LLMEngine:
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache on each scheduler in ``self.scheduler``.

    Schedulers are visited in order. As soon as one reports a falsy
    result the remaining schedulers are NOT asked to reset, and that
    falsy result is returned. With no schedulers the result is ``True``.

    Returns:
        The result of the last ``scheduler.reset_prefix_cache()`` call
        that was made (``True`` when every scheduler succeeded).
    """
    outcome = True
    for sched in self.scheduler:
        if not outcome:
            # A previous scheduler failed: stop without touching the rest.
            break
        outcome = sched.reset_prefix_cache()
    return outcome
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
......@@ -1038,8 +1013,23 @@ class LLMEngine:
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(seq_group_metadata_list))
if self.scheduler_config.is_multi_step:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, len(seq_group_metadata_list))
elif self.speculative_config:
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match decodes structure (costly with prompt_logprobs).
num_prefills = sum(sg.is_prompt
for sg in seq_group_metadata_list)
prefills, decodes = outputs[:num_prefills], outputs[
num_prefills:]
outputs_by_sequence_group = create_output_by_sequence_group(
decodes,
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
outputs_by_sequence_group = [p.outputs for p in prefills
] + outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
......@@ -1141,6 +1131,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
......@@ -1183,6 +1175,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
......@@ -1703,7 +1697,7 @@ class LLMEngine:
# If the seq_group just finished the prefill state
# get TTFT.
if not seq_group.is_prefill():
latency = seq_group.get_last_latency(now)
latency = seq_group.get_last_token_latency()
time_to_first_tokens_iter.append(latency)
# One generation token per finished prefill.
......@@ -1711,7 +1705,7 @@ class LLMEngine:
seq_group.num_seqs())
else:
# TPOTs.
latency = seq_group.get_last_latency(now)
latency = seq_group.get_last_token_latency()
time_per_output_tokens_iter.append(latency)
if seq_group.state.current_step == 0:
# For async_output_proc, the do_log_stats()
......@@ -1858,27 +1852,27 @@ class LLMEngine:
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def start_profile(self) -> None:
    """Ask the model executor to start profiling.

    Pure delegation to ``self.model_executor.start_profile()``.
    """
    executor = self.model_executor
    executor.start_profile()
def stop_profile(self) -> None:
    """Ask the model executor to stop profiling.

    Pure delegation to ``self.model_executor.stop_profile()``.
    """
    executor = self.model_executor
    executor.stop_profile()
def sleep(self, level: int = 1) -> None:
    """Put the model executor to sleep.

    Args:
        level: Forwarded unchanged as the ``level`` keyword of
            ``model_executor.sleep``; its meaning is defined by the
            executor, not here.

    Raises:
        AssertionError: If ``model_config.enable_sleep_mode`` is falsy.
    """
    sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
    assert sleep_mode_enabled, "Sleep mode is not enabled in the model config"
    executor = self.model_executor
    executor.sleep(level=level)
def wake_up(self) -> None:
    """Wake the model executor up again.

    Delegates to ``self.model_executor.wake_up()`` after checking that
    sleep mode was enabled in the model config.

    Raises:
        AssertionError: If ``model_config.enable_sleep_mode`` is falsy.
    """
    sleep_mode_enabled = self.vllm_config.model_config.enable_sleep_mode
    assert sleep_mode_enabled, "Sleep mode is not enabled in the model config"
    executor = self.model_executor
    executor.wake_up()
def check_health(self) -> None:
    """Run the health checks of the tokenizer (if any) and the executor.

    The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
    model executor check always runs, and runs second.
    """
    tokenizer = self.tokenizer
    if tokenizer:
        tokenizer.check_health()
    executor = self.model_executor
    executor.check_health()
def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.start_profile()
else:
self.model_executor._run_workers("start_profile")
def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor: # noqa: E721
self.model_executor.stop_profile()
else:
self.model_executor._run_workers("stop_profile")
def is_tracing_enabled(self) -> bool:
return self.tracer is not None
......@@ -1913,46 +1907,44 @@ class LLMEngine:
metrics = seq_group.metrics
ttft = metrics.first_token_time - metrics.arrival_time
e2e_time = metrics.finished_time - metrics.arrival_time
# attribute names are based on
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
self.model_config.model)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
seq_group.request_id)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
seq_group.sampling_params.temperature)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
seq_group.sampling_params.top_p)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
seq_group.sampling_params.max_tokens)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
seq_group.sampling_params.n)
seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES,
seq_group.num_seqs())
seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(seq_group.prompt_token_ids))
seq_span.set_attribute(
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
sum([
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
]))
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
metrics.time_in_queue)
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
if metrics.scheduler_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER,
metrics.scheduler_time)
if metrics.model_forward_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD,
metrics.model_forward_time / 1000.0)
if metrics.model_execute_time is not None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
def _validate_model_inputs(self, inputs: ProcessorInputs,
......
......@@ -120,7 +120,8 @@ class Metrics:
labelnames=labelnames)
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.capture_sizes.copy()
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
buckets.sort()
self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total",
......@@ -258,21 +259,6 @@ class Metrics:
documentation="Number of emitted tokens.",
labelnames=labelnames))
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._gauge_cls(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
multiprocess_mode="sum",
)
# end-metrics-definitions
......@@ -634,20 +620,6 @@ class PrometheusStatLogger(StatLoggerBase):
self._log_histogram(self.metrics.histogram_max_tokens_request,
stats.max_tokens_requests)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on
# the vLLM side. Moving forward, we should use counters like
# counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the
# grafana/prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats):
"""Logs to prometheus and tracked stats every iteration."""
# Log to prometheus.
......@@ -663,20 +635,6 @@ class PrometheusStatLogger(StatLoggerBase):
# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
# Compute summary metrics for tracked stats (and log them
# to promethus if applicable).
prompt_throughput = get_throughput(self.num_prompt_tokens,
now=stats.now,
last_log=self.last_local_log)
generation_throughput = get_throughput(
self.num_generation_tokens,
now=stats.now,
last_log=self.last_local_log)
self._log_prometheus_interval(
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
......
from dataclasses import dataclass
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union, overload
......@@ -120,10 +121,28 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2
class RPCResetPrefixCacheRequest(Enum):
    """RPC request kind with a single member, ``RESET_PREFIX_CACHE``.

    The class is one of the alternatives listed in ``RPC_REQUEST_T``.
    """

    # Sole member; its wire value is the integer 1.
    RESET_PREFIX_CACHE = 1
@dataclass
class RPCLoadAdapterRequest:
    """RPC request that carries a ``LoRARequest`` and a request id.

    When the caller does not supply ``request_id``, each instance gets its
    own freshly generated UUID4 string.
    """

    # The LoRA request being transported.
    lora_request: LoRARequest
    # Defaults to a new random UUID (as a string) per instance.
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCAdapterLoadedResponse:
    """RPC response message identified by a ``request_id``.

    NOTE(review): presumably echoes the ``request_id`` of the
    ``RPCLoadAdapterRequest`` it answers -- confirm against the sender.
    """

    # Identifier carried by this response.
    request_id: str
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetPrefixCacheRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError]
def ENGINE_DEAD_ERROR(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment