Commit 081057de authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
......@@ -4,6 +4,7 @@ import hashlib
import inspect
import json
import types
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Union
import torch
......@@ -18,6 +19,34 @@ else:
from .torch25_custom_graph_pass import ( # noqa: yapf
Torch25CustomGraphPass as CustomGraphPass)
_pass_context = None
class PassContext:
def __init__(self, runtime_shape: Optional[int]):
self.runtime_shape = runtime_shape
def get_pass_context() -> PassContext:
"""Get the current pass context."""
assert _pass_context is not None
return _pass_context
@contextmanager
def pass_context(runtime_shape: Optional[int]):
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
global _pass_context
prev_context = _pass_context
_pass_context = PassContext(runtime_shape)
try:
yield
finally:
_pass_context = prev_context
class InductorPass(CustomGraphPass):
"""
......@@ -62,6 +91,9 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_shape(self, shape: Optional[int]):
return True
class CallableInductorPass(InductorPass):
"""
......
......@@ -4,13 +4,15 @@ from typing import List
from torch import fx as fx
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import CustomGraphPass, InductorPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
......@@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: List[InductorPass] = []
self.passes: List[VllmInductorPass] = []
def __call__(self, graph: fx.Graph):
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
pass_(graph)
if pass_.is_applicable_for_shape(shape):
pass_(graph)
# always run fix_functionalization last
self.fix_functionalization(graph)
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.fix_functionalization = FixFunctionalizationPass(pass_config)
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class AllReduceRMSNormPattern:
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
self.epsilon = epsilon
self.dtype = dtype
self.device = device
class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype)
mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]],
device=self.device,
dtype=torch.long)
unsqueeze = torch.rand([1, 8, 1], device=self.device, \
dtype=self.dtype) > 0.5
full_default = torch.zeros([1, 8, 4], device=self.device, \
dtype=self.dtype)
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
all_reduce = tensor_model_parallel_all_reduce(where)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=permute,
input=all_reduce,
weight=arg3_1,
epsilon=self.epsilon,
)
return rmsnorm[1], all_reduce
def replacement(
arg2_1: torch.Tensor,
mul_6: torch.Tensor,
unsqueeze: torch.Tensor,
full_default: torch.Tensor,
permute: torch.Tensor,
arg3_1: torch.Tensor,
):
embedding = torch.ops.aten.embedding.default(arg2_1, mul_6)
where = torch.ops.aten.where.self(unsqueeze, full_default,
embedding)
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
where, dim=0, world_size=tp_size, group_name=tp.unique_name)
rmsnorm_result = torch.empty_like(reduce_scatter)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.rms_norm.default,
result=rmsnorm_result,
input=reduce_scatter,
weight=arg3_1,
epsilon=self.epsilon,
)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return all_gather, reduce_scatter
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
return rmsnorm[1], rmsnorm[2]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
all_gather = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return all_gather, rmsnorm[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
def get_inputs(self):
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
rms_norm_weights = torch.empty([4, 4],
device=self.device,
dtype=self.dtype)
return [
residual,
mm_1,
rms_norm_weights,
]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=all_reduce,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
return rmsnorm[1]
def replacement(
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name)
# TODO is it possible to extract epsilon from somewhere
rmsnorm = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduce_scatter,
residual=residual,
weight=rms_norm_weights,
epsilon=self.epsilon,
)
normalized = torch.ops.vllm.all_gather.default(
rmsnorm[1],
dim=0,
world_size=tp_size,
group_name=tp.unique_name)
return normalized
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="sequence_parallelism_pass")
for epsilon in [1e-5, 1e-6]:
EmbeddingAllReduceRMSNormPattern(
epsilon, self.dtype, self.device).register(self.patterns)
MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
self.device).register(self.patterns)
LastAllReduceRMSNormPattern(epsilon, self.dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
def __call__(self, graph: fx.Graph):
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
......@@ -4,7 +4,7 @@ import time
import torch
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
......@@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass):
It provides timing, logging, and dumping utilities.
"""
def __init__(self, config: CompilationConfig.PassConfig):
self.config = config
def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
self.dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config \
else None
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
if stage in self.config.dump_graph_stages or always:
if stage in self.pass_config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("%s printing graph to %s", self.pass_name, filepath)
with open(filepath, "w") as f:
......
......@@ -6,18 +6,18 @@ import enum
import hashlib
import inspect
import json
import re
import sys
import textwrap
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union)
Optional, Protocol, TypeVar, Union, get_args)
import torch
from pydantic import BaseModel, Field, PrivateAttr
......@@ -28,6 +28,7 @@ import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
......@@ -52,16 +53,16 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
Config = TypeVar("Config", bound=DataclassInstance)
ConfigType = type[DataclassInstance]
else:
QuantizationConfig = None
Config = TypeVar("Config")
ConfigType = type
logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
......@@ -121,7 +122,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
......@@ -163,7 +164,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
def config(cls: type[Config]) -> type[Config]:
def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
......@@ -182,6 +183,23 @@ def config(cls: type[Config]) -> type[Config]:
return cls
def get_field(cls: ConfigType, name: str) -> Field:
"""Get the default factory field of a dataclass by name. Used for getting
default factory fields in `EngineArgs`."""
if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.")
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields.get(name)
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
return field(default=default)
raise ValueError(
f"{cls.__name__}.{name} must have a default value or default factory.")
class ModelConfig:
"""Configuration for the model.
......@@ -250,7 +268,7 @@ class ModelConfig:
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
......@@ -298,12 +316,20 @@ class ModelConfig:
factors.append(self.quantization)
factors.append(self.revision)
factors.append(self.code_revision)
factors.append(self.max_model_len)
factors.append(self.max_logprobs)
factors.append(self.disable_sliding_window)
factors.append(self.trust_remote_code)
factors.append(self.mm_processor_kwargs)
factors.append(self.generation_config)
factors.append(self.model_impl)
factors.append(self.override_generation_config)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
# rope cos/sin cache depends on the max_position_embeddings
factors.append(
getattr(self.hf_config, "max_position_embeddings", "None"))
# hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string())
str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __init__(
......@@ -332,7 +358,7 @@ class ModelConfig:
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
......@@ -417,8 +443,10 @@ class ModelConfig:
from vllm.platforms import current_platform
if self.enable_sleep_mode and not current_platform.is_cuda():
raise ValueError("Sleep mode is only supported on CUDA devices.")
if (self.enable_sleep_mode
and not current_platform.is_sleep_mode_available()):
raise ValueError(
"Sleep mode is not supported on current platform.")
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
......@@ -553,7 +581,7 @@ class ModelConfig:
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
self, limit_mm_per_prompt: Optional[dict[str, int]]
) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
......@@ -725,8 +753,8 @@ class ModelConfig:
supported_quantization = QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark", "nvfp4"
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "nvfp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......@@ -736,13 +764,47 @@ class ModelConfig:
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_method.replace("compressed_tensors",
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
"marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"ipex",
"moe_wna16",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
]
# Any custom overrides will be in quantization_methods so we place
# them at the start of the list so custom overrides have preference
# over the built in ones.
quantization_methods = quantization_methods + overrides
# Detect which checkpoint is it
for name in QUANTIZATION_METHODS:
for name in quantization_methods:
method = get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
if quantization_override is not None:
# Raise error if the override is not custom (custom would
# be in QUANTIZATION_METHODS but not QuantizationMethods)
# and hasn't been added to the overrides list.
if (name in get_args(QuantizationMethods)
and name not in overrides):
raise ValueError(
f"Quantization method {name} is an override but "
"is has not been added to the `overrides` list "
"above. This is necessary to ensure that the "
"overrides are checked in order of preference.")
quant_method = quantization_override
self.quantization = quantization_override
break
......@@ -1220,23 +1282,78 @@ class ModelConfig:
return (hasattr(self.hf_config, "matryoshka_dimensions")
or getattr(self.hf_config, "is_matryoshka", False))
@property
def matryoshka_dimensions(self):
return getattr(self.hf_config, "matryoshka_dimensions", None)
class CacheConfig:
"""Configuration for the KV cache.
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
is_attention_free: Whether the model is attention-free.
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
sliding_window: Sliding window size for the KV cache.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
@config
@dataclass
class CacheConfig:
"""Configuration for the KV cache."""
block_size: BlockSize = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_configs()` based on the current
platform."""
gpu_memory_utilization: float = 0.9
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
utilization. If unspecified, will use the default value of 0.9. This is a
per-instance limit, and only applies to the current vLLM instance. It does
not matter if you have another vLLM instance running on the same GPU. For
example, if you have two vLLM instances running on the same GPU, you can
set the GPU memory utilization to 0.5 for each instance."""
swap_space: float = 4
"""Size of the CPU swap space per GPU (in GiB)."""
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
num_gpu_blocks_override: Optional[int] = None
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
if specified. Does nothing if `None`. Used for testing preemption."""
sliding_window: Optional[int] = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: Optional[bool] = None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
default for V1."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""
def compute_hash(self) -> str:
"""
......@@ -1257,43 +1374,13 @@ class CacheConfig:
usedforsecurity=False).hexdigest()
return hash_str
def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: float,
cache_dtype: str,
is_attention_free: bool = False,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
prefix_caching_hash_algo: str = "builtin",
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.prefix_caching_hash_algo = prefix_caching_hash_algo
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
# Will be set after profiling.
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None
# Set calculate_kv_scales to False if the value is unset.
if self.calculate_kv_scales is None:
self.calculate_kv_scales = False
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
......@@ -1312,7 +1399,7 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
elif self.cache_dtype in get_args(CacheDType):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
......@@ -1330,12 +1417,12 @@ class CacheConfig:
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
"builtin", "sha256"):
if (self.enable_prefix_caching and self.prefix_caching_hash_algo
not in get_args(PrefixCachingHashAlgo)):
raise ValueError(
"Unknown prefix caching hash algorithm: "
f"{self.prefix_caching_hash_algo}. Must be either "
"'builtin' or 'sha256'.")
f"{self.prefix_caching_hash_algo}. Must be one of "
f"{get_args(PrefixCachingHashAlgo)}.")
def verify_with_parallel_config(
self,
......@@ -1356,77 +1443,33 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
@config
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
"""This config is deprecated and will be removed in a future release.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
Passing these parameters will have no effect. Please remove them from your
configurations.
"""
pool_size: int
pool_type: Union[str, type["BaseTokenizerGroup"]]
extra_config: dict
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
pool_size: int = 0
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
pool_type: str = "ray"
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
extra_config: dict = field(default_factory=dict)
"""This parameter is deprecated and will be removed in a future release.
Passing this parameter will have no effect. Please remove it from your
configurations."""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@classmethod
def create_config(
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
def __post_init__(self) -> None:
logger.warning_once(
"TokenizerPoolConfig is deprecated and will be removed in a "
"future release. Passing this parameter will have no effect. "
"Please remove it from your configurations.")
class LoadFormat(str, enum.Enum):
......@@ -1441,6 +1484,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors"
......@@ -1475,7 +1519,7 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
model_loader_extra_config: Optional[Union[str, dict]] = None
model_loader_extra_config: dict = field(default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that
will be parsed into a dictionary."""
......@@ -1506,10 +1550,6 @@ class LoadConfig:
return hash_str
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(
model_loader_extra_config)
if isinstance(self.load_format, str):
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
......@@ -1522,6 +1562,9 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class ParallelConfig:
......@@ -1536,8 +1579,21 @@ class ParallelConfig:
the product of the tensor parallel size and data parallel size."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: Optional[int] = None
"""Local rank of the data parallel group, defaults to global rank."""
_data_parallel_rank_local: Optional[int] = field(default=None, init=False)
"""Private field to store the local rank of the data parallel group."""
@property
def data_parallel_rank_local(self) -> int:
"""Local rank of the data parallel group, defaults to global rank."""
if self._data_parallel_rank_local is None:
return self.data_parallel_rank
return self._data_parallel_rank_local
@data_parallel_rank_local.setter
def data_parallel_rank_local(self, value: int) -> None:
"""Set the local rank of the data parallel group."""
self._data_parallel_rank_local = value
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_master_port: int = 29500
......@@ -1554,8 +1610,8 @@ class ParallelConfig:
"""Disable the custom all-reduce kernel and fall back to NCCL."""
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""Config for the tokenizer pool. If None, will use synchronous
tokenization."""
"""This parameter is deprecated and will be removed in a future release.
Please remove it from your configs"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
......@@ -1563,7 +1619,7 @@ class ParallelConfig:
placement_group: Optional["PlacementGroup"] = None
"""ray distributed model workers placement group."""
distributed_executor_backend: Optional[Union[str,
distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
type["ExecutorBase"]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
......@@ -1577,7 +1633,7 @@ class ParallelConfig:
"""The full name of the worker class to use. If "auto", the worker class
will be determined based on the platform."""
sd_worker_cls: str = "auto"
"""The full name of the worker class to use for speculative decofing.
"""The full name of the worker class to use for speculative decofing.
If "auto", the worker class will be determined based on the platform."""
worker_extension_cls: str = ""
"""The full name of the worker extension class to use. The worker extension
......@@ -1646,6 +1702,7 @@ class ParallelConfig:
factors: list[Any] = []
factors.append(self.pipeline_parallel_size)
factors.append(self.tensor_parallel_size)
factors.append(self.enable_expert_parallel)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
......@@ -1687,7 +1744,7 @@ class ParallelConfig:
# current node and we aren't in a ray placement group.
from vllm.executor import ray_utils
backend = "mp"
backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available()
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
......@@ -1755,92 +1812,125 @@ class ParallelConfig:
"worker_extension_cls must be a string (qualified class name).")
PreemptionMode = Literal["swap", "recompute"]
SchedulerPolicy = Literal["fcfs", "priority"]
@config
@dataclass
class SchedulerConfig:
"""Scheduler configuration."""
runner_type: str = "generate" # The runner type to launch for the model.
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
max_num_batched_tokens: int = None # type: ignore
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
# Maximum number of tokens to be processed in a single iteration.
max_num_batched_tokens: int = field(default=None) # type: ignore
max_num_seqs: int = None # type: ignore
"""Maximum number of sequences to be processed in a single iteration.
# Maximum number of sequences to be processed in a single iteration.
max_num_seqs: int = 128
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
# Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192
max_model_len: int = None # type: ignore
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
"""For chunked prefill, the maximum number of prompts longer than
long_prefill_token_threshold that will be prefilled concurrently. Setting
this less than max_num_partial_prefills will allow shorter prompts to jump
the queue in front of longer prompts in some cases, improving latency."""
# calculate context length that determines which sequences are
# considered "long"
long_prefill_token_threshold: int = 0
"""For chunked prefill, a request is considered long if the prompt is
longer than this number of tokens."""
# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
# accepted.
num_lookahead_slots: int = 0
"""The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
# Apply a delay (of delay factor multiplied by previous
# prompt latency) before scheduling next prompt.
delay_factor: float = 0.0
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""
# If True, prefill requests can be chunked based
# on the remaining max_num_batched_tokens.
enable_chunked_prefill: bool = False
enable_chunked_prefill: bool = None # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
is_multimodal_model: bool = False
"""True if the model is multimodal."""
# TODO (ywang96): Make this configurable.
max_num_encoder_input_tokens: int = field(init=False)
"""Multimodal encoder compute budget, only used in V1.
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# NOTE: The following multimodal encoder budget will be initialized to
# max_num_batched_tokens and overridden in case max multimodal embedding
# size is larger.
# TODO (ywang96): Make these configurable.
# Multimodal encoder compute budget, only used in V1
max_num_encoder_input_tokens: int = field(default=None) # type: ignore
# TODO (ywang96): Make this configurable.
encoder_cache_size: int = field(init=False)
"""Multimodal encoder cache size, only used in V1.
# Multimodal encoder cache size, only used in V1
encoder_cache_size: int = field(default=None) # type: ignore
NOTE: This is not currently configurable. It will be overridden by
max_num_batched_tokens in case max multimodal embedding size is larger."""
# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not currently supported. In
# such a case, we use swapping instead.
preemption_mode: Optional[str] = None
preemption_mode: Optional[PreemptionMode] = None
"""Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead."""
num_scheduler_steps: int = 1
"""Maximum number of forward steps per scheduler call."""
multi_step_stream_outputs: bool = False
multi_step_stream_outputs: bool = True
"""If False, then multi-step will stream outputs at the end of all steps"""
# Private API. If used, scheduler sends delta data to
# workers instead of an entire data. It should be enabled only
# when SPMD worker architecture is enabled. I.e.,
# VLLM_USE_RAY_SPMD_WORKER=1
send_delta_data: bool = False
# The scheduling policy to use. "fcfs" (default) or "priority".
policy: str = "fcfs"
"""Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1"""
policy: SchedulerPolicy = "fcfs"
"""The scheduling policy to use:\n
- "fcfs" means first come first served, i.e. requests are handled in order
of arrival.\n
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled: bool = field(init=False)
"""True if chunked prefill is enabled."""
# If set to true and chunked prefill is enabled, we do not want to
# partially schedule a multimodal item. Only used in V1
# This ensures that if a request has a mixed prompt
# (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
# some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
# it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
This ensures that if a request has a mixed prompt
(like text tokens TTTT followed by image tokens IIIIIIIIII) where only
some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
"""The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
default scheduler. Can be a class directly or the path to a class of form
"mod.custom_class"."""
def compute_hash(self) -> str:
"""
......@@ -1862,6 +1952,18 @@ class SchedulerConfig:
return hash_str
def __post_init__(self) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
logger.warning(
"max_model_len was is not set. Defaulting to arbitrary value "
"of %d.", self.max_model_len)
if self.max_num_seqs is None:
self.max_num_seqs = 128
logger.warning(
"max_num_seqs was is not set. Defaulting to arbitrary value "
"of %d.", self.max_num_seqs)
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
if self.num_scheduler_steps > 1:
......@@ -1974,9 +2076,19 @@ class SchedulerConfig:
return self.num_scheduler_steps > 1
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
@config
@dataclass
class DeviceConfig:
device: Optional[torch.device]
device_type: str
"""Configuration for the device to use for vLLM execution."""
device: Union[Device, torch.device] = "auto"
"""Device type for vLLM execution."""
device_type: str = field(init=False)
"""Device type from the current platform. This is set in
`__post_init__`."""
def compute_hash(self) -> str:
"""
......@@ -1998,8 +2110,8 @@ class DeviceConfig:
usedforsecurity=False).hexdigest()
return hash_str
def __init__(self, device: str = "auto") -> None:
if device == "auto":
def __post_init__(self):
if self.device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
......@@ -2010,7 +2122,7 @@ class DeviceConfig:
"to turn on verbose logging to help debug the issue.")
else:
# Device type is assigned explicitly
self.device_type = device
self.device_type = self.device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
......@@ -2022,139 +2134,113 @@ class DeviceConfig:
self.device = torch.device(self.device_type)
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@config
@dataclass
class SpeculativeConfig:
"""
Configuration for speculative decoding.
Configurable parameters include:
- General Speculative Decoding Control:
- num_speculative_tokens (int): The number of speculative
tokens, if provided. It will default to the number in the draft
model config if present, otherwise, it is required.
- model (Optional[str]): The name of the draft model, eagle head,
or additional weights, if provided.
- method (Optional[str]): The name of the speculative method to use.
If users provide and set the `model` param, the speculative method
type will be detected automatically if possible, if `model` param
is not provided, the method name must be provided.
- Possible values:
- ngram
Related additional configuration:
- prompt_lookup_max (Optional[int]):
Maximum size of ngram token window when using Ngram
proposer, required when method is set to ngram.
- prompt_lookup_min (Optional[int]):
Minimum size of ngram token window when using Ngram
proposer, if provided. Defaults to 1.
- eagle
- medusa
- mlp_speculator
- draft_model
- acceptance_method (str): The method to use for accepting draft
tokens. This can take two possible values: 'rejection_sampler' and
'typical_acceptance_sampler' for RejectionSampler and
TypicalAcceptanceSampler respectively. If not specified, it
defaults to 'rejection_sampler'.
- Possible values:
- rejection_sampler
- typical_acceptance_sampler
Related additional configuration:
- posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the
posterior probability of a token in the target model
for it to be accepted. This threshold is used only
when we use the TypicalAcceptanceSampler for token
acceptance.
- posterior_alpha (Optional[float]):
Scaling factor for entropy-based threshold, applied
when using TypicalAcceptanceSampler.
- draft_tensor_parallel_size (Optional[int]): The degree of the tensor
parallelism for the draft model. Can only be 1 or the same as the
target model's tensor parallel size.
- disable_logprobs (bool): If set to True, token log probabilities are
not returned during speculative decoding. If set to False, token
log probabilities are returned according to the log probability
settings in SamplingParams. If not specified, it defaults to True.
- Draft Model Configuration:
- quantization (Optional[str]): Quantization method that was used to
quantize the draft model weights. If None, we assume the
model weights are not quantized. Note that it only takes effect
when using the draft model-based speculative method.
- max_model_len (Optional[int]): The maximum model length of the
draft model. Used when testing the ability to skip
speculation for some sequences.
- revision: The specific model version to use for the draft model. It
can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version.
- code_revision: The specific revision to use for the draft model code
on Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
"""Configuration for speculative decoding."""
- Advanced Control:
- disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
batch expansion for scoring proposals. If not specified, it
defaults to False.
- disable_by_batch_size (Optional[int]): Disable speculative decoding
for new incoming requests when the number of enqueued requests is
larger than this value, if provided.
Although the parameters above are structured hierarchically, there is no
need to nest them during configuration.
Non-configurable internal parameters include:
- Model Configuration:
- target_model_config (ModelConfig): The configuration of the target
model.
- draft_model_config (ModelConfig): The configuration of the draft
model initialized internal.
- Parallelism Configuration:
- target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
- draft_parallel_config (ParallelConfig): The parallel configuration
for the draft model initialized internal.
- Execution Control:
- enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with speculative decode.
- disable_log_stats (bool): Whether to disable the periodic printing of
stage times in speculative decoding.
"""
# speculative configs from cli args
# General speculative decoding control
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
method: Optional[str] = None
acceptance_method: str = "rejection_sampler"
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: Optional[str] = None
"""The name of the draft model, eagle head, or additional weights, if
provided."""
method: Optional[SpeculativeMethod] = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible, if `model` param is not provided, the method name must be
provided.
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
"""The method to use for accepting draft tokens:\n
- "rejection_sampler" maps to `RejectionSampler`.\n
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
If using `typical_acceptance_sampler`, the related configuration
`posterior_threshold` and `posterior_alpha` should be considered."""
draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""
model: Optional[str] = None
# Draft model configuration
quantization: Optional[str] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: Optional[int] = None
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: Optional[str] = None
"""The specific model version to use for the draft model. It can be a
branch name, a tag name, or a commit id. If unspecified, will use the
default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the draft model code on Hugging Face
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version."""
# Advanced control
disable_mqa_scorer: bool = False
"""Disable the MQA scorer and fall back to batch expansion for scoring
proposals."""
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: Optional[int] = None
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
# Typical acceptance sampler configuration
posterior_threshold: Optional[float] = None
"""A threshold value that sets a lower bound on the posterior probability
of a token in the target model for it to be accepted. This threshold is
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
"""
posterior_alpha: Optional[float] = None
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
"""The configuration of the target model."""
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: bool = field(default=None, init=True) # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""
# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
"""The parallel configuration for the draft model initialized internal."""
def compute_hash(self) -> str:
"""
......@@ -2168,9 +2254,10 @@ class SpeculativeConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# spec decode does not use `torch.compile` yet.
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
......@@ -2205,7 +2292,8 @@ class SpeculativeConfig:
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \
if self.target_model_config and \
self.target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
self.model = self.target_model_config.model
......@@ -2286,7 +2374,10 @@ class SpeculativeConfig:
)
# Automatically detect the method
if "eagle-" in self.draft_model_config.model.lower():
if self.method in ('eagle', 'eagle3'):
pass
elif "eagle-" in self.draft_model_config.model.lower() or \
"eagle3-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
......@@ -2297,7 +2388,7 @@ class SpeculativeConfig:
self.method = "draft_model"
# Replace hf_config for EAGLE draft_model
if self.method == "eagle":
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
......@@ -2442,7 +2533,6 @@ class SpeculativeConfig:
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
disable_custom_all_reduce,
tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
......@@ -2495,6 +2585,12 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
if self.method == "eagle3" and self.target_model_config and \
"llama" not in self.target_model_config.hf_text_config.model_type:
raise ValueError(
"Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}")
@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
......@@ -2505,6 +2601,9 @@ class SpeculativeConfig:
"""
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3")
def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
......@@ -2512,18 +2611,41 @@ class SpeculativeConfig:
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
LoRADType = Literal["auto", "float16", "bfloat16"]
@config
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
"""Configuration for LoRA."""
max_lora_rank: int = 16
"""Max LoRA rank."""
max_loras: int = 1
"""Max number of LoRAs in a single batch."""
fully_sharded_loras: bool = False
"""By default, only half of the LoRA computation is sharded with tensor
parallelism. Enabling this will use the fully sharded layers. At high
sequence length, max rank or tensor parallel size, this is likely faster.
"""
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[Union[torch.dtype, str]] = None
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
`max_loras`."""
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[tuple[float]] = None
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
bias_enabled: bool = False
"""Enable bias for LoRA adapters."""
def compute_hash(self) -> str:
"""
......@@ -2582,25 +2704,27 @@ class LoRAConfig:
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
logger.warning("LoRA with chunked prefill is still experimental "
"and may be unstable.")
def verify_lora_support(self):
if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1:
raise ValueError(
"V1 LoRA does not support long LoRA, please use V0.")
@config
@dataclass
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
"""Configuration for PromptAdapters."""
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""
def compute_hash(self) -> str:
"""
......@@ -2632,20 +2756,26 @@ class PromptAdapterConfig:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@config
@dataclass
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
limit_per_prompt: dict[str, int] = field(default_factory=dict)
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt:
``{"images": 16, "videos": 2}``
"""
def compute_hash(self) -> str:
......@@ -2667,24 +2797,20 @@ class MultiModalConfig:
usedforsecurity=False).hexdigest()
return hash_str
def get_default_limit_per_prompt(self) -> int:
"""
Return the default number of input items allowed per prompt
for any modality if not specified by the user.
"""
return 999 if envs.VLLM_USE_V1 else 1
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
"""
default = self.get_default_limit_per_prompt()
return self.limit_per_prompt.get(modality, default)
return self.limit_per_prompt.get(
modality,
999 if envs.VLLM_USE_V1 else 1,
)
# TODO: Add configs to init vision tower or not.
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
......@@ -2762,12 +2888,10 @@ def _get_and_verify_dtype(
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
# Fallbacks for multi-modal models if the root config
# Fallback for multi-modal models if the root config
# does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
......@@ -2783,6 +2907,13 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16
from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
......@@ -2812,6 +2943,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
......@@ -2997,15 +3133,28 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
@config
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""
"""Dataclass which contains the decoding strategy of the engine."""
# Which guided decoding algo to use.
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
guided_decoding_backend: Union[
GuidedDecodingBackendV0,
GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
is subject to change in each release."""
reasoning_backend: Optional[str] = None
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format.
Required for `--enable-reasoning`."""
def compute_hash(self) -> str:
"""
......@@ -3027,17 +3176,12 @@ class DecodingConfig:
return hash_str
def __post_init__(self):
v0_valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
]
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if envs.VLLM_USE_V1:
valid_guided_backends = v1_valid_guided_backends
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = v0_valid_guided_backends
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
f" must be one of {valid_guided_backends}")
......@@ -3297,11 +3441,13 @@ class CompilationConfig(BaseModel):
- enable_fusion: whether to enable the custom fusion pass.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
- enable_sequence_parallelism: whether to enable sequence parallelism.
"""
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
enable_sequence_parallelism: bool = False
def uuid(self):
"""
......@@ -3310,7 +3456,8 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
"enable_sequence_parallelism"})
return InductorPass.hash_dict(dict_)
def model_post_init(self, __context: Any) -> None:
......@@ -3337,7 +3484,8 @@ class CompilationConfig(BaseModel):
compilation_time: float = PrivateAttr
# Per-model forward context
# Map from layer name to the attention cls
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context: dict[str, Any] = PrivateAttr
def compute_hash(self) -> str:
......@@ -3668,6 +3816,17 @@ class VllmConfig:
return quant_config
return None
@staticmethod
def get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
import copy
# For some reason, the _ version of this modifies the model_config
# object, so using deepcopy to avoid this problem.
return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
load_config)
def with_hf_config(
self,
hf_config: PretrainedConfig,
......@@ -3697,8 +3856,6 @@ class VllmConfig:
if self.lora_config:
self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
self.lora_config.verify_lora_support()
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
......@@ -3722,6 +3879,8 @@ class VllmConfig:
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
# NOTE(woosuk): Currently, we use inductor because the piecewise
......@@ -3729,7 +3888,8 @@ class VllmConfig:
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
# FIXME(rob): Add function to set all of these.
self.compilation_config.custom_ops = ["none"]
if not self.compilation_config.custom_ops:
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
......@@ -3738,6 +3898,18 @@ class VllmConfig:
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
if self.parallel_config is not None and \
self.parallel_config.tensor_parallel_size > 1 and \
self.parallel_config.pipeline_parallel_size > 1 and \
self.compilation_config is not None and \
self.compilation_config.pass_config is not None and \
self.compilation_config.pass_config.enable_sequence_parallelism:
logger.warning_once(
"Sequence parallelism is not supported with pipeline "
"parallelism. Disabling sequence parallelism.")
self.compilation_config.pass_config.\
enable_sequence_parallelism = False
self._set_cudagraph_sizes()
if self.cache_config is not None and \
......@@ -3777,6 +3949,26 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
removed_sizes = [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size != 0
]
if removed_sizes:
logger.warning(
"Batch sizes %s are removed because they are not "
"multiple of tp_size %d when "
"sequence parallelism is enabled", removed_sizes,
self.parallel_config.tensor_parallel_size)
return [
size for size in possible_sizes
if size % self.parallel_config.tensor_parallel_size == 0
]
def _set_cudagraph_sizes(self):
"""
cudagraph batchsize padding logic:
......@@ -3814,6 +4006,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
possible_sizes = self.update_sizes_for_sequence_parallelism(
possible_sizes)
# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes = [
......@@ -3837,6 +4034,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list
......@@ -3935,3 +4137,43 @@ def get_current_vllm_config() -> VllmConfig:
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config
def contains_object_print(text):
"""
Check if the text looks like a printed Python object, e.g.
contains any substring matching the pattern: "at 0xFFFFFFF>"
We match against 0x followed by 2-16 hex chars (there's
a max of 16 on a 64 bit system).
Args:
text (str): The text to check
Returns:
bool: True if a match is found, False otherwise
"""
pattern = r'at 0x[a-fA-F0-9]{2,16}>'
match = re.search(pattern, text)
return match is not None
def assert_hashable(text):
if not contains_object_print(text):
return True
raise AssertionError(
f"vLLM tried to hash some configs that may have Python objects ids "
f"in them. This is a bug, please file an issue. "
f"Text being hashed: {text}")
T = TypeVar("T")
def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}
......@@ -1596,7 +1596,6 @@ class Scheduler:
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
......
......@@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""Reduce-Scatter the input tensor across model parallel group."""
return get_tp_group().reduce_scatter(input_, dim)
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
......
......@@ -61,6 +61,40 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def gather(self,
input_: torch.Tensor,
dst: int = 0,
......
......@@ -70,6 +70,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
......
......@@ -7,11 +7,13 @@ import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional, Tuple, Union
from threading import Event
from typing import Any, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed as dist
import zmq
from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
......@@ -239,7 +241,7 @@ class MessageQueue:
self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://*:{remote_subscribe_port}"
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else:
......@@ -400,7 +402,9 @@ class MessageQueue:
break
@contextmanager
def acquire_read(self, timeout: Optional[float] = None):
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
......@@ -430,6 +434,9 @@ class MessageQueue:
)
n_warning += 1
if cancel is not None and cancel.is_set():
raise RuntimeError("cancelled")
# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
......@@ -464,10 +471,12 @@ class MessageQueue:
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self, timeout: Optional[float] = None):
def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read(timeout) as buf:
with self.acquire_read(timeout, cancel) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
......@@ -475,15 +484,21 @@ class MessageQueue:
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf[1:])
if overflow:
recv = self.local_socket.recv()
obj = pickle.loads(recv)
obj = MessageQueue.recv(self.local_socket, timeout)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
obj = pickle.loads(recv)
obj = MessageQueue.recv(self.remote_socket, timeout)
else:
raise RuntimeError("Only readers can dequeue")
return obj
@staticmethod
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms):
raise TimeoutError
recv = socket.recv(copy=False)
return pickle.loads(recv.buffer)
def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
......
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group,
has_kv_transfer_group, is_v1_kv_transfer_group)
__all__ = [
"get_kv_transfer_group", "has_kv_transfer_group",
"is_v1_kv_transfer_group", "ensure_kv_transfer_initialized",
"KVConnectorBaseType"
]
......@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
......@@ -121,3 +122,6 @@ class KVConnectorBase(ABC):
"""
raise NotImplementedError
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
......@@ -3,14 +3,22 @@
import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.logger import init_logger
from .base import KVConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
_registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {}
@classmethod
def register_connector(cls, name: str, module_path: str,
......@@ -19,22 +27,51 @@ class KVConnectorFactory:
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> Type[KVConnectorBase]:
def loader() -> Type[KVConnectorBaseType]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
def create_connector_v0(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
if envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V0 Connector, "
f"but found {envs.VLLM_USE_V1=}")
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
assert issubclass(connector_cls, KVConnectorBase)
return connector_cls(rank, local_rank, config)
@classmethod
def create_connector_v1(
cls,
config: "VllmConfig",
role: KVConnectorRole,
) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
f"but found {envs.VLLM_USE_V1=}")
connector_name = config.kv_transfer_config.kv_connector
connector_cls = cls._registry[connector_name]()
assert issubclass(connector_cls, KVConnectorBase_V1)
logger.info("Creating v1 connector with name: %s", connector_name)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
# - Should only be used inside the Scheduler class
# Worker connector:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
......@@ -57,4 +94,14 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector(
"MooncakeStoreConnector",
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
"MooncakeStoreConnector")
\ No newline at end of file
"MooncakeStoreConnector")
KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector")
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1")
# SPDX-License-Identifier: Apache-2.0
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
......@@ -11,9 +10,10 @@ from typing import TYPE_CHECKING, List, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
......@@ -32,8 +32,7 @@ class MooncakeStoreConnector(KVConnectorBase):
config: VllmConfig,
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.kv_helper = kv_helper(config)
self.local_tp_rank = local_rank
# Init kv_store
......@@ -80,12 +79,7 @@ class MooncakeStoreConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
......@@ -97,10 +91,8 @@ class MooncakeStoreConnector(KVConnectorBase):
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
......@@ -173,22 +165,15 @@ class MooncakeStoreConnector(KVConnectorBase):
layer = model_executable.model.layers[layer_id]
# get kvcache object
kv_cache = kv_caches[layer_id - start_layer]
key_cache, value_cache = kv_cache[0], kv_cache[1]
# get remote kvcache
# get remote kvcache
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
layer_id]
# use ops.reshape_and_cache_flash to put kv into kvcache
ops.reshape_and_cache_flash(
remote_k.to(key_cache.device),
remote_v.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)
hidden_or_intermediate_states_for_one_req.append(hidden)
......
......@@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
......@@ -37,9 +37,7 @@ class SimpleConnector(KVConnectorBase):
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.kv_helper = kv_helper(config)
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
......@@ -165,31 +163,7 @@ class SimpleConnector(KVConnectorBase):
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
......@@ -212,13 +186,8 @@ class SimpleConnector(KVConnectorBase):
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
kv_cache, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
......@@ -248,12 +217,12 @@ class SimpleConnector(KVConnectorBase):
# and hidden states.
bypass_model_exec = True
model_config = model_executable.model.config
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
hidden_or_intermediate_states_for_one_req = []
......@@ -312,41 +281,19 @@ class SimpleConnector(KVConnectorBase):
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys[
i - model_executable.model.start_layer].to(
kv_cache.device).squeeze(1)
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
ops.concat_and_cache_mla(
k_c_normed,
k_pe,
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
for cur_layer in range(start_layer, end_layer):
layer_id = cur_layer - start_layer
kv_cache = kv_caches[layer_id]
layer = model_executable.model.layers[cur_layer]
# get remote kvcache
remote_k, remote_v = keys[layer_id], values[layer_id]
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
remote_v, layer, kv_cache,
slot_mapping, start_pos,
end_pos)
hidden_or_intermediate_states_for_one_req.append(hidden)
......
# SPDX-License-Identifier: Apache-2.0
"""
KV cache helper for store.
"""
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class model_aware_kv_ops_helper:
def __init__(self, config: VllmConfig):
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
self.tp_size = config.parallel_config.tensor_parallel_size
def get_model_args(self, model_executable: torch.nn.Module):
model_config = model_executable.model.config
self.model_executable = model_executable
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
# Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
return num_heads, head_size
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
if self.is_deepseek_mla and self.use_mla_opt:
key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
return key_cache, value_cache
def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
layer, kv_cache, slot_mapping, start_pos, end_pos):
model_config = model_executable.model.config
if self.is_deepseek_mla and self.use_mla_opt:
layer.self_attn.attn = layer.self_attn.mla_attn
k_c_normed_k_pe = keys.squeeze(1)
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
ops.concat_and_cache_mla(
k_c_normed.to(kv_cache.device),
k_pe.to(kv_cache.device),
kv_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys.to(key_cache.device),
values.to(value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole)
__all__ = [
"KVConnectorRole",
"KVConnectorBase_V1",
]
# SPDX-License-Identifier: Apache-2.0
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
"""
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
# Connector running in the worker process
WORKER = 1
@dataclass
class KVConnectorMetadata:
pass
class KVConnectorBase_V1(ABC):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata()
self._vllm_config = vllm_config
self._role = role
@property
def role(self) -> KVConnectorRole:
return self._role
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
connector_metadata (dict): the connector metadata.
"""
self._connector_metadata = connector_metadata
def clear_connector_metadata(self) -> None:
"""Clear the connector metadata.
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = KVConnectorMetadata()
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
return self._connector_metadata
# ==============================
# Worker-side methods
# ==============================
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
pass
@abstractmethod
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
pass
@abstractmethod
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
pass
@abstractmethod
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
pass
# ==============================
# Scheduler-side methods
# ==============================
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
pass
@abstractmethod
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""
Start saving the a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
def wait_for_save(self):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
self._lmcache_engine.wait_for_save()
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
self._lmcache_engine.update_state_after_alloc(request,
num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
@staticmethod
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
is_store: bool) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is Simple debug implementation of the KV connector.
# It save / load the KV cache to / from the disk.
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(vllm_config.kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning(
"In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info("Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping))
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[\
forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = safetensors.torch.load_file(
filename)["kv_cache"].cuda()
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids)
kv_cache = extract_kv_from_layer(kv_layer,
request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_size=self._block_size,
is_store=True)
for cached_req in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[cached_req.req_id]
total_tokens = (len(cached_req.new_token_ids) +
cached_req.num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request.
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
create_folder=False)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
input_ids: torch.Tensor,
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
input_ids_bytes = input_ids.numpy().tobytes()
input_ids_hash = hashlib.md5(input_ids_bytes,
usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
input_ids: torch.Tensor,
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(input_ids,
create_folder=True)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size) -> int:
"""Align the number of tokens to the block size.
"""
return (num_tokens - 1) // block_size * block_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment