Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
......@@ -6,7 +6,6 @@ import enum
import hashlib
import inspect
import json
import re
import textwrap
import uuid
import warnings
......@@ -20,6 +19,7 @@ from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
import regex as re
import torch
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
......@@ -42,7 +42,10 @@ from vllm.transformers_utils.config import (
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
......@@ -64,12 +67,6 @@ logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
......@@ -536,13 +533,19 @@ class ModelConfig:
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config. TODO: remove this
# when Gemma 2 is fixed in Transformers.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and (
isinstance(sliding_window, list) or
(self.hf_text_config.model_type in interleaved_attn_models))
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
has_interleaved_attention = sliding_window_pattern is not None or (
isinstance(sliding_window, list))
if (not self.disable_sliding_window and has_interleaved_attention):
if not self.disable_sliding_window and has_interleaved_attention:
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
......@@ -562,7 +565,10 @@ class ModelConfig:
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
if hasattr(self.hf_text_config, "sliding_window"):
delattr(self.hf_text_config, "sliding_window")
sliding_window = None
self.max_model_len = _get_and_verify_max_len(
......@@ -824,7 +830,7 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "nvfp4", "bitblas", "gptq_bitblas"
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = cast(QuantizationMethods,
......@@ -987,7 +993,7 @@ class ModelConfig:
self.use_async_output_proc = False
return
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo becomes valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
......@@ -1003,7 +1009,7 @@ class ModelConfig:
if self.runner_type == "pooling":
self.use_async_output_proc = False
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo becomes valid
if speculative_config:
self.use_async_output_proc = False
......@@ -2074,28 +2080,28 @@ class SchedulerConfig:
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
self.max_num_batched_tokens = max(
self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
else:
self.max_num_batched_tokens = (
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
DEFAULT_MAX_NUM_BATCHED_TOKENS)
else:
# If max_model_len is too short, use
# _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
# When using default settings,
......@@ -2201,7 +2207,11 @@ class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""
device: Union[Device, torch.device] = "auto"
"""Device type for vLLM execution."""
"""Device type for vLLM execution.
This parameter is deprecated and will be
removed in a future release.
It will now be set automatically based
on the current platform."""
device_type: str = field(init=False)
"""Device type from the current platform. This is set in
`__post_init__`."""
......@@ -2251,7 +2261,7 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"]
"draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
......@@ -2515,6 +2525,15 @@ class SpeculativeConfig:
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp"):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
......@@ -2525,11 +2544,10 @@ class SpeculativeConfig:
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
EAGLEConfig) or current_platform.is_neuron():
EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(
......@@ -2735,7 +2753,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3")
return self.method in ("eagle", "eagle3", "deepseek_mtp")
def __repr__(self) -> str:
method = self.method
......@@ -2968,7 +2986,7 @@ class PoolerConfig:
pooling_type: Optional[str] = None
"""
The pooling method of the pooling model. This should be a key in
{class}`vllm.model_executor.layers.pooler.PoolingType`.
[`vllm.model_executor.layers.pooler.PoolingType`][].
"""
normalize: Optional[bool] = None
......@@ -3491,7 +3509,7 @@ class KVTransferConfig:
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: str = str(uuid.uuid4())
engine_id: Optional[str] = None
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
......@@ -3548,6 +3566,9 @@ class KVTransferConfig:
return hash_str
def __post_init__(self) -> None:
if self.engine_id is None:
self.engine_id = str(uuid.uuid4())
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
f"Supported roles are {get_args(KVRole)}")
......@@ -3646,6 +3667,8 @@ class PassConfig:
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
"""Whether to enable sequence parallelism."""
enable_async_tp: bool = False
"""Whether to enable async TP."""
def uuid(self):
"""
......@@ -3655,7 +3678,8 @@ class PassConfig:
compilation.
"""
include = {
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
"enable_fusion", "enable_noop", "enable_sequence_parallelism",
"enable_async_tp"
}
dict_ = {k: v for k, v in asdict(self).items() if k in include}
return InductorPass.hash_dict(dict_)
......@@ -3673,23 +3697,27 @@ class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- {attr}`level`
- {attr}`debug_dump_path`
- {attr}`cache_dir`
- {attr}`backend`
- {attr}`custom_ops`
- {attr}`splitting_ops`
- [`level`][vllm.config.CompilationConfig.level]
- [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
- [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
- [`backend`][vllm.config.CompilationConfig.backend]
- [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- {attr}`use_cudagraph`
- {attr}`cudagraph_capture_sizes`
- {attr}`cudagraph_num_of_warmups`
- {attr}`cudagraph_copy_inputs`
- {attr}`full_cuda_graph`
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- Inductor compilation:
- {attr}`use_inductor`
- {attr}`compile_sizes`
- {attr}`inductor_compile_config`
- {attr}`inductor_passes`
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
- custom inductor passes
Why we have different sizes for cudagraph and inductor:
......@@ -4268,6 +4296,12 @@ class VllmConfig:
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = \
True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
......@@ -4312,18 +4346,6 @@ class VllmConfig:
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled.")
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
......@@ -4549,7 +4571,7 @@ def contains_object_print(text):
text (str): The text to check
Returns:
bool: True if a match is found, False otherwise
result (bool): `True` if a match is found, `False` otherwise.
"""
pattern = r'at 0x[a-fA-F0-9]{2,16}>'
match = re.search(pattern, text)
......
......@@ -167,4 +167,7 @@ class HTTPConnection:
global_http_connection = HTTPConnection()
"""The global {class}`HTTPConnection` instance used by vLLM."""
"""
The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used
by vLLM.
"""
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from .base_device_communicator import All2AllManagerBase, Cache
class All2AllBase:
def __init__(self, cpu_group, model):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
logger = init_logger(__name__)
def destroy(self):
pass
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None
class NaiveAll2All(All2AllBase):
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
......@@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase):
debugging.
"""
def __init__(self, cpu_group, model):
super().__init__(cpu_group, model)
def __init__(self, cpu_group):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
......@@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase):
def destroy(self):
pass
class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.

    Handles created through `get_handle` are memoized in `handle_cache`
    and torn down in `destroy`.
    """

    def __init__(self, cpu_group):
        """Check that pplx_kernels is importable and set up NVSHMEM if needed.

        Args:
            cpu_group: Process group forwarded to the base class; it is also
                used to broadcast the NVSHMEM unique id.
        """
        pplx_spec = importlib.util.find_spec("pplx_kernels")
        assert pplx_spec is not None, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                              nvshmem_get_unique_id,
                                              nvshmem_init)
            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: "
                "rank=%d, world size=%d", self.rank, self.world_size)
            # Rank 0 obtains the unique id; every other rank allocates an
            # empty one that the broadcast below fills in.
            if self.rank == 0:
                uid = nvshmem_get_unique_id()
            else:
                uid = nvshmem_alloc_empty_unique_id()
            # Broadcast from the first rank listed for this cpu group.
            root_rank = dist.get_process_group_ranks(self.cpu_group)[0]
            dist.broadcast(uid, src=root_rank, group=self.cpu_group)
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        """Return the cached PPLX all-to-all handle for `kwargs`.

        The handle is built by the inter-node or intra-node constructor
        depending on `self.internode`.
        """
        import pplx_kernels as pplx

        if self.internode:
            factory = pplx.AllToAll.internode
        else:
            factory = pplx.AllToAll.intranode
        return self.handle_cache.get_or_create(kwargs, factory)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Not implemented by this manager."""
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Not implemented by this manager."""
        raise NotImplementedError

    def destroy(self):
        """Destroy every cached handle, then finalize NVSHMEM if it was used."""
        # Hold the cache lock while walking the cached handles.
        with self.handle_cache._lock:
            for handle in self.handle_cache._cache.values():
                handle.destroy()

        if self.internode:
            from pplx_kernels.nvshmem import nvshmem_finalize
            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
from weakref import WeakValueDictionary
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Cache:
    """Lock-protected memoizer keyed by keyword arguments.

    Created objects are stored in a `WeakValueDictionary`, so the cache
    itself does not keep them alive.
    """

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        # Reentrant lock guarding lookups and insertions.
        self._lock = threading.RLock()

    def get_or_create(self, kwargs, func):
        """Return the object cached for `kwargs`, building it on a miss.

        Args:
            kwargs: Mapping of keyword arguments; its sorted items form the
                cache key, so the values must be hashable.
            func: Callable invoked as `func(**kwargs)` when nothing is
                cached for this key.

        Returns:
            The cached or newly created object.
        """
        # Sorting the items makes the key independent of insertion order.
        cache_key = tuple(sorted(kwargs.items()))
        with self._lock:
            cached = self._cache.get(cache_key)
            if cached is not None:
                return cached
            created = func(**kwargs)
            self._cache[cache_key] = created
            return created
class All2AllManagerBase:
    """Base class for all2all communication managers.

    The constructor records group and rank information; `get_handle`,
    `dispatch` and `combine` must be provided by subclasses, while
    `destroy` defaults to doing nothing.
    """

    def __init__(self, cpu_group):
        """Record the group, rank and node-locality properties.

        Args:
            cpu_group: Process group used to query rank and world size and
                to determine whether all ranks share one node.
        """
        self.cpu_group = cpu_group

        # compute some common properties
        from vllm.distributed.parallel_state import (get_dp_group,
                                                     get_tp_group,
                                                     in_the_same_node_as)

        # all2all lives in ep group, which is merged from dp and tp group
        dp_group = get_dp_group()
        self.dp_group = dp_group
        self.tp_group = get_tp_group()

        # no self.ep_group since self.ep_group is still in construction
        # when we create this object

        self.dp_rank = dp_group.rank_in_group
        self.dp_world_size = dp_group.world_size
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)

        # all2all communication often has separate implementations for
        # intra-node and inter-node communication
        same_node = in_the_same_node_as(cpu_group, source_rank=0)
        self.intranode = same_node
        self.internode = not same_node

    def get_handle(self, kwargs):
        """Return a handle for the all2all communication described by `kwargs`.

        Different layers can have different configs (e.g. one layer has
        hidden size 1024, another has 2048); implementations usually cache
        the handle and reuse it for the same config.
        """
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Dispatch step; must be implemented by subclasses."""
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine step; must be implemented by subclasses."""
        raise NotImplementedError

    def destroy(self):
        """Release resources; nothing to do in the base class."""
        pass
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
......@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
......@@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
if not self.use_all2all:
return
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
def dispatch(
self, hidden_states: torch.Tensor,
......
......@@ -22,8 +22,10 @@ class CpuCommunicator(DeviceCommunicatorBase):
super().__init__(cpu_group, device, device_group, unique_name)
self.dist_module = torch.distributed
if (current_platform.get_cpu_architecture() == CpuArchEnum.X86) \
and hasattr(torch.ops._C, "init_shm_manager"):
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and unique_name.startswith("tp"):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(self, input_):
......@@ -96,6 +98,8 @@ class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
......
......@@ -6,10 +6,12 @@ import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase):
......@@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
......@@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
......@@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_impl is not None:
self.all2all_impl.destroy()
self.all2all_impl = None
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None
hidden_states = self.all2all_impl.combine(hidden_states)
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
import os
import pickle
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
......@@ -19,7 +17,7 @@ from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address)
......@@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger = init_logger(__name__)
# os.sched_yield is preferred because it gives tighter polling loops
# (measured at around 3e-7 seconds). On earlier versions of Python,
# os.sched_yield() does not release the GIL, so time.sleep(0) is used there.
USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)
                   or (3, 10, 8) <= sys.version_info[:3] < (3, 11))


def sched_yield():
    """Yield the CPU using the mechanism selected by `USE_SCHED_YIELD`."""
    if not USE_SCHED_YIELD:
        time.sleep(0)
        return
    os.sched_yield()
class ShmRingBuffer:
......
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group,
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
has_kv_transfer_group, is_v1_kv_transfer_group)
__all__ = [
......
......@@ -31,12 +31,12 @@ class MooncakeStoreConnector(KVConnectorBase):
local_rank: int,
config: VllmConfig,
):
self.config = config.kv_transfer_config
self.kv_transfer_config = config.kv_transfer_config
self.kv_helper = kv_helper(config)
self.local_tp_rank = local_rank
# Init kv_store
if self.config.kv_connector == "MooncakeStoreConnector":
if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector":
# Check if MOONCAKE_CONFIG_PATH is set
import os
use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None
......@@ -50,10 +50,11 @@ class MooncakeStoreConnector(KVConnectorBase):
MooncakeStore)
logger.info(
"Initializing KVStoreConnector under kv_transfer_config %s",
self.config)
self.kv_transfer_config)
self.kv_store = MooncakeStore(config)
else:
logger.error("Can not find %s", self.config.kv_connector)
logger.error("Can not find %s",
self.kv_transfer_config.kv_connector)
assert self.kv_store is not None
......
......@@ -106,7 +106,7 @@ class SimpleConnector(KVConnectorBase):
else:
# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producder
# its recv pipe to the send pipe of KV producer
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
......
......@@ -44,8 +44,9 @@ class model_aware_kv_ops_helper:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
head_size = getattr(model_config, "head_dim", None)
if head_size is None:
head_size = int(hidden_size // num_attention_heads)
return num_heads, head_size
......
......@@ -210,9 +210,10 @@ class KVConnectorBase_V1(ABC):
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
A tuple with the following elements:
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if external KV cache tokens will be loaded
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
pass
......
......@@ -40,7 +40,7 @@ class MultiConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors = []
self._connectors: list[KVConnectorBase_V1] = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors")
assert ktcs is not None
......
......@@ -259,6 +259,15 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
if not block_ids:
logger.debug(
"Skipping adding request %s to NixlConnectorMetadata, "
"as there are no remote blocks to pull", req_id)
continue
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
......@@ -528,6 +537,7 @@ class NixlConnectorWorker:
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
engine_id = nixl_agent_meta.engine_id
assert engine_id != self.engine_id, "Conflict engine id found!"
if engine_id in self._remote_agents:
return
......
......@@ -118,11 +118,11 @@ class PyNcclPipe(KVPipeBase):
"""
Create the metadata as a dictionary based on the input tensor.
Parameters:
- tensor: The input tensor or None if no tensor is provided.
Args:
tensor: The input tensor or None if no tensor is provided.
Returns:
- metadata: A dictionary with the following keys:
metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
......@@ -135,13 +135,13 @@ class PyNcclPipe(KVPipeBase):
"""
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Args:
metadata: A dictionary with keys "dtype" and "shape",
describing the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
buffer: A tensor of the specified type and shape,
allocated on `self.device`.
"""
return torch.empty(metadata["shape"],
dtype=metadata["dtype"],
......@@ -151,8 +151,8 @@ class PyNcclPipe(KVPipeBase):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
Args:
metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
......@@ -161,8 +161,8 @@ class PyNcclPipe(KVPipeBase):
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
metadata: A dictionary with keys "dtype" and "shape"
describing the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
......@@ -171,8 +171,8 @@ class PyNcclPipe(KVPipeBase):
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
Args:
tensor: The input tensor to be sent, or `None` if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
......@@ -187,7 +187,7 @@ class PyNcclPipe(KVPipeBase):
the target rank.
Returns:
- buffer: The received tensor, or None if no tensor is received.
buffer: The received tensor, or `None` if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
......@@ -227,8 +227,8 @@ class PyNcclPipe(KVPipeBase):
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
Args:
tensor: The tensor to send, or `None` if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
......@@ -250,8 +250,8 @@ class PyNcclPipe(KVPipeBase):
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
Returns:
tensor: The received tensor, or `None` if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
......
......@@ -23,7 +23,6 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
......@@ -43,7 +42,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
run_once, supports_custom_op)
supports_custom_op)
@dataclass
......@@ -120,7 +119,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.reduce_scatter(tensor, dim)
return group._reduce_scatter_out_place(tensor, dim)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
......@@ -136,7 +135,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.all_gather(tensor, dim)
return group._all_gather_out_place(tensor, dim)
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
......@@ -161,6 +160,7 @@ if supports_custom_op():
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
......@@ -168,6 +168,7 @@ if supports_custom_op():
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
dispatch_key=current_platform.dispatch_key,
)
......@@ -367,6 +368,16 @@ class GroupCoordinator:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if self.use_custom_op_call:
return torch.ops.vllm.all_gather(input_,
dim,
world_size,
group_name=self.unique_name)
else:
return self._all_gather_out_place(input_, dim)
def _all_gather_out_place(self, input_: torch.Tensor,
                          dim: int) -> torch.Tensor:
    """All-gather `input_` along `dim` using the device communicator.

    This calls `self.device_communicator.all_gather` directly and returns
    its result unchanged.

    Args:
        input_: The tensor to gather.
        dim: The dimension passed through to the communicator.

    Returns:
        Whatever `self.device_communicator.all_gather` returns.
    """
    return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
......@@ -379,6 +390,16 @@ class GroupCoordinator:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if self.use_custom_op_call:
return torch.ops.vllm.reduce_scatter(input_,
dim,
world_size,
group_name=self.unique_name)
else:
return self._reduce_scatter_out_place(input_, dim)
def _reduce_scatter_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
    """Reduce-scatter `input_` along `dim` using the device communicator.

    This calls `self.device_communicator.reduce_scatter` directly and
    returns its result unchanged.

    Args:
        input_: The tensor to reduce-scatter.
        dim: The dimension passed through to the communicator.

    Returns:
        Whatever `self.device_communicator.reduce_scatter` returns.
    """
    return self.device_communicator.reduce_scatter(input_, dim)
def gather(self,
......@@ -769,10 +790,14 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
else:
return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor:
    """Combine `hidden_states` through the device communicator.

    When no device communicator is set, `hidden_states` is handed back
    unchanged.
    """
    communicator = self.device_communicator
    if communicator is None:
        return hidden_states
    return communicator.combine(hidden_states)
_WORLD: Optional[GroupCoordinator] = None
......@@ -937,49 +962,9 @@ def init_distributed_environment(
"world group already initialized with a different world size")
PPLX_DID_INIT: bool = False
# NOTE(review): `run_once` is defined elsewhere; presumably it limits this
# function to a single execution per process -- confirm.
@run_once
def pplx_init(rank, world_size):
    """Initialize NVSHMEM for the PPLX kernels when they are available.

    Nothing happens unless the `pplx_kernels` package can be found and
    `world_size` is greater than 1. On success the module-level flag
    `PPLX_DID_INIT` is set to True. Any exception raised during
    initialization is logged as an error and not re-raised.

    Args:
        rank: Rank of this process; rank 0 creates the NVSHMEM unique ID.
        world_size: Number of participating ranks.
    """
    has_pplx = importlib.util.find_spec("pplx_kernels") is not None

    if has_pplx and world_size > 1:
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id, nvshmem_init)
        try:
            global PPLX_DID_INIT
            logger.debug(
                "Initialize NVSHMEM for PPLX kernels: rank=%d, "
                "world size=%d", rank, world_size)
            # Rank 0 generates the real unique ID; the other ranks allocate
            # an empty one that the broadcast below fills in.
            uid = nvshmem_get_unique_id(
            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
            # The ID is moved to the GPU for the world-group broadcast and
            # copied back to the CPU before being handed to nvshmem_init.
            uid_gpu = uid.cuda()
            get_world_group().broadcast(uid_gpu, src=0)
            uid = uid_gpu.to(device='cpu')
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, rank, world_size)
            PPLX_DID_INIT = True
        except Exception as ex:
            # Best effort: a failure here is reported but does not abort
            # the caller; PPLX_DID_INIT simply stays False.
            logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
# NOTE(review): `run_once` is defined elsewhere; presumably it limits this
# function to a single execution per process -- confirm.
@run_once
def pplx_finalize():
    """Tear down NVSHMEM if `pplx_init` initialized it.

    Does nothing when `PPLX_DID_INIT` is False. Otherwise it destroys the
    fused-MoE `_all_to_all_cache` and then calls `nvshmem_finalize`.
    """
    global PPLX_DID_INIT
    if PPLX_DID_INIT:
        from pplx_kernels.nvshmem import nvshmem_finalize
        logger.debug("PPLX NVSHMEM finalize")
        from vllm.model_executor.layers.fused_moe.layer import (
            _all_to_all_cache)

        # The cache is destroyed before NVSHMEM is finalized.
        # NOTE(review): presumably the cached objects hold NVSHMEM
        # resources -- confirm.
        _all_to_all_cache.destroy()
        nvshmem_finalize()
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""
......@@ -1082,14 +1067,10 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)
if enable_expert_parallel:
pplx_init(rank, world_size)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
......@@ -1100,8 +1081,7 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size,
enable_expert_parallel, backend)
pipeline_model_parallel_size, backend)
return
assert (
......@@ -1180,8 +1160,6 @@ def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
pplx_finalize()
if _TP:
_TP.destroy()
_TP = None
......@@ -1221,8 +1199,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()
empty_cache = current_platform.empty_cache
if empty_cache is not None:
empty_cache()
try:
torch._C._host_emptyCache()
except AttributeError:
......
......@@ -6,9 +6,12 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import os
import pickle
import socket
import sys
import time
import uuid
from collections import deque
from collections.abc import Sequence
from typing import Any, Optional
......@@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
logger = init_logger(__name__)
# os.sched_yield() gives the tightest polling loops (measured at around
# 3e-7 seconds), so it is preferred. On earlier versions of Python it does
# not release the GIL, and time.sleep(0) is used there instead.
USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)
                   or (3, 10, 8) <= sys.version_info[:3] < (3, 11))


def sched_yield():
    """Yield to other threads via the call selected by USE_SCHED_YIELD."""
    if not USE_SCHED_YIELD:
        time.sleep(0)
        return
    os.sched_yield()
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
......@@ -212,10 +229,141 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self, timeout: float = 30.0):
    """Synchronize all ranks through the key-value store.

    The barrier proceeds in phases:

    1. Rank 0 broadcasts a freshly generated barrier ID so that every call
       works on its own set of keys.
    2. Each rank writes an arrival key and then polls until it has seen
       the arrival key of every rank.
    3. Each rank writes a departure key. Non-zero ranks return right away;
       rank 0 polls until it has seen every departure key and then deletes
       the keys created by this barrier.

    Args:
        timeout: Maximum number of seconds to wait in each polling phase.

    Raises:
        RuntimeError: If broadcasting the barrier ID or writing a key
            fails, or if a polling phase takes longer than `timeout`.
    """
    # Generate a barrier ID that is globally unique
    try:
        if self.rank == 0:
            barrier_id = f"barrier_{uuid.uuid4()}"
            self.broadcast_obj(barrier_id, src=0)
        else:
            barrier_id = self.broadcast_obj(None, src=0)
    except Exception as e:
        raise RuntimeError("Failed to broadcast barrier_id") from e

    # Phase 1: Signal arrival at barrier
    # Wait for all processes to arrive
    # We need all ranks to confirm the arrival of all other ranks.
    # This is the key synchronization point.
    arrival_key = f"arrival_{barrier_id}_{self.rank}"
    try:
        self.store.set(arrival_key, b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier arrival") from e

    start_time = time.time()
    processes_arrived: set[int] = set()

    while len(processes_arrived) < self.world_size:
        # Check for timeout. The message is formatted here because
        # exception constructors do not apply %-style arguments.
        if time.time() - start_time > timeout:
            raise RuntimeError(
                f"Barrier timed out after {timeout:f} seconds")

        # Check for each process
        for i in range(self.world_size):
            if i in processes_arrived:
                continue

            key = f"arrival_{barrier_id}_{i}"
            try:
                # Try to get the key - if it exists, we'll get a value
                # If it doesn't exist, it will throw an exception
                self.store.get(key)
                processes_arrived.add(i)
            except KeyError:
                # Key doesn't exist yet
                pass
            except Exception as check_e:
                logger.debug("Error checking key existence: %s", check_e)
                sched_yield()

        # Yield between polling rounds while ranks are still missing
        if len(processes_arrived) < self.world_size:
            sched_yield()

    # Phase 2: Signal departure from barrier
    # We only care to block at this stage in rank 0, which runs the
    # server side of the TCPStore. We want to make sure that all
    # clients have departed the barrier before rank 0 in case the
    # next thing after the barrier is a shutdown, including tearing
    # down the TCPStore. Other ranks can exit the barrier immediately
    # after signaling their departure.
    departure_key = f"departure_{barrier_id}_{self.rank}"
    try:
        self.store.set(departure_key, b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier departure") from e

    if self.rank != 0:
        return

    # Make rank 0 wait for all processes to signal departure
    start_time = time.time()
    processes_departed: set[int] = set()

    while len(processes_departed) < self.world_size:
        # Check for timeout
        if time.time() - start_time > timeout:
            raise RuntimeError(
                f"Barrier departure timed out after {timeout:f} s")

        # Check for each process
        for i in range(self.world_size):
            if i in processes_departed:
                continue

            key = f"departure_{barrier_id}_{i}"
            try:
                # Try to get the key - if it exists, we'll get a value
                # If it doesn't exist, it will throw an exception
                self.store.get(key)
                processes_departed.add(i)
            except KeyError:
                # Key doesn't exist yet
                pass
            except Exception as check_e:
                logger.debug("Error checking key existence: %s", check_e)
                sched_yield()

        # Yield between polling rounds while ranks are still missing
        if len(processes_departed) < self.world_size:
            sched_yield()

    # Clean up keys to avoid leaking memory in the store. Only rank 0 gets
    # here (every other rank returned above), so this section must not
    # issue calls that need the other ranks to take part.
    for i in range(self.world_size):
        for key in (f"arrival_{barrier_id}_{i}",
                    f"departure_{barrier_id}_{i}"):
            try:
                self.store.delete_key(key)
            except Exception:
                # Best effort: a key that cannot be deleted is only logged.
                logger.debug("Error deleting key: %s", key)
@staticmethod
def create(
......
......@@ -4,7 +4,6 @@
import argparse
import dataclasses
import json
import re
import sys
import threading
import warnings
......@@ -13,6 +12,7 @@ from itertools import permutations
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import regex as re
import torch
from typing_extensions import TypeIs, deprecated
......@@ -577,7 +577,7 @@ class EngineArgs:
action=argparse.BooleanOptionalAction,
deprecated=True,
help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
"of v0.8.6. Use `--reasoning-parser` to specify the reasoning "
"of v0.9.0. Use `--reasoning-parser` to specify the reasoning "
"parser backend instead. This flag (`--enable-reasoning`) will be "
"removed in v0.10.0. When `--reasoning-parser` is specified, "
"reasoning mode is automatically enabled.")
......@@ -737,7 +737,9 @@ class EngineArgs:
title="DeviceConfig",
description=DeviceConfig.__doc__,
)
device_group.add_argument("--device", **device_kwargs["device"])
device_group.add_argument("--device",
**device_kwargs["device"],
deprecated=True)
# Speculative arguments
speculative_group = parser.add_argument_group(
......@@ -977,7 +979,7 @@ class EngineArgs:
from vllm.platforms import current_platform
current_platform.pre_register_and_update()
device_config = DeviceConfig(device=self.device)
device_config = DeviceConfig(device=current_platform.device_type)
model_config = self.create_model_config()
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
......@@ -1082,7 +1084,7 @@ class EngineArgs:
disable_log_stats=self.disable_log_stats,
)
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo becomes valid
if self.num_scheduler_steps > 1:
if speculative_config is not None:
......@@ -1193,8 +1195,7 @@ class EngineArgs:
#############################################################
# Unsupported Feature Flags on V1.
if (self.load_format == LoadFormat.TENSORIZER.value
or self.load_format == LoadFormat.SHARDED_STATE.value):
if self.load_format == LoadFormat.SHARDED_STATE.value:
_raise_or_fallback(
feature_name=f"--load_format {self.load_format}",
recommend_to_remove=False)
......@@ -1290,14 +1291,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# Some quantization is not compatible with torch.compile.
V1_UNSUPPORTED_QUANT = ["gguf"]
if model_config.quantization in V1_UNSUPPORTED_QUANT:
_raise_or_fallback(
feature_name=f"--quantization {model_config.quantization}",
recommend_to_remove=False)
return False
# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
......@@ -1337,7 +1330,7 @@ class EngineArgs:
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3"):
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
......
......@@ -475,7 +475,8 @@ class _AsyncLLMEngine(LLMEngine):
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of {meth}`add_request`."""
"""Async version of
[`add_request`][vllm.engine.llm_engine.LLMEngine.add_request]."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
......@@ -582,20 +583,21 @@ async def build_guided_decoding_logits_processor_async(
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for {class}`LLMEngine`.
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
This class is used to wrap the {class}`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The {class}`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the {class}`LLMEngine` to the caller.
This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
make it asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
by the generate method when there are requests in the waiting queue. The
generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
to the caller.
Args:
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for {class}`LLMEngine`.
**kwargs: Arguments for {class}`LLMEngine`.
*args: Arguments for [`LLMEngine`][vllm.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
......@@ -985,8 +987,9 @@ class AsyncLLMEngine(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
......@@ -1003,7 +1006,7 @@ class AsyncLLMEngine(EngineClient):
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
{meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
[`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
......@@ -1075,8 +1078,9 @@ class AsyncLLMEngine(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
......@@ -1091,7 +1095,7 @@ class AsyncLLMEngine(EngineClient):
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
{meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
[`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
......
......@@ -130,26 +130,16 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The {class}`~vllm.LLM` class wraps this class for offline batched inference
and the {class}`AsyncLLMEngine` class wraps this class for online serving.
The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
class wraps this class for online serving.
The config arguments are derived from {class}`~vllm.EngineArgs`. (See
{ref}`engine-args`)
The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
speculative_config (Optional): The configuration related to speculative
decoding.
vllm_config: The configuration for initializing and running vLLM.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
......@@ -695,11 +685,12 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
prompt: The prompt to the LLM. See
[PromptType][vllm.inputs.PromptType]
for more details about the format of each input.
params: Parameters for sampling or pooling.
{class}`~vllm.SamplingParams` for text generation.
{class}`~vllm.PoolingParams` for pooling.
[SamplingParams][vllm.SamplingParams] for text generation.
[PoolingParams][vllm.PoolingParams] for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
lora_request: The LoRA request to add.
......@@ -711,10 +702,11 @@ class LLMEngine:
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `n` number of {class}`~vllm.Sequence` objects.
- Create a {class}`~vllm.SequenceGroup` object
from the list of {class}`~vllm.Sequence`.
- Add the {class}`~vllm.SequenceGroup` object to the scheduler.
- Create `n` number of [Sequence][vllm.Sequence] objects.
- Create a [SequenceGroup][vllm.SequenceGroup] object
from the list of [Sequence][vllm.Sequence].
- Add the [SequenceGroup][vllm.SequenceGroup] object to the
scheduler.
Example:
>>> # initialize engine
......@@ -861,9 +853,7 @@ class LLMEngine:
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
{meth}`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class {class}`~vllm.core.scheduler.Scheduler`.
- Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][].
Example:
>>> # initialize engine and add a request with request_id
......@@ -1263,12 +1253,10 @@ class LLMEngine:
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
:::{figure} https://i.imgur.com/sv2HssD.png
:alt: Overview of the step function
:align: center
Overview of the step function.
:::
<figure markdown="span">
![Overview of the step function](https://i.imgur.com/sv2HssD.png)
<figcaption>Overview of the step function</figcaption>
</figure>
Details:
- Step 1: Schedules the sequences to be executed in the next
......@@ -1662,6 +1650,20 @@ class LLMEngine:
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)
# Exchange the usage and cache hit stats between gpu and cpu when
# running on cpu because the cpu_worker.py intentionally reports the
# number of cpu blocks as gpu blocks in favor of cache management.
if self.device_config.device_type == "cpu":
num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu
gpu_cache_usage_sys, cpu_cache_usage_sys = (
cpu_cache_usage_sys,
gpu_cache_usage_sys,
)
gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = (
cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate,
)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment