Commit 711aa9d5 authored by zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
......@@ -23,6 +23,10 @@ class CompilationCounter:
num_inductor_compiles: int = 0
# EagerAdapter.compile calls
num_eager_compiles: int = 0
# The number of times vLLM's compiler cache entry was updated
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
def clone(self) -> "CompilationCounter":
    """Return an independent deep copy of this counter.

    The copy is produced with `copy.deepcopy`, so it shares no mutable state
    with the original object.
    """
    return copy.deepcopy(self)
......
......@@ -20,9 +20,38 @@ from .monitor import start_monitoring_torch_compile
logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=type[nn.Module])
def ignore_torch_compile(cls: _T) -> _T:
    """
    A decorator to ignore the support_torch_compile decorator on a class.

    This is useful when a parent class has a support_torch_compile
    decorator, but we don't want to compile the class `cls` that inherits
    from that parent class.

    This only ignores compiling the forward of the class the decorator is
    applied to.

    If the parent has ignore_torch_compile but the child has
    support_torch_compile, the child will still be compiled.

    If the class has one or more submodules that have the
    support_torch_compile decorator applied, compile will not be ignored
    for those submodules.
    """
    # Mark the class itself; the flag is read back via
    # `getattr(cls, IGNORE_COMPILE_KEY, False)`.
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls
def _should_ignore_torch_compile(cls: type) -> bool:
    """
    Check if the class should be ignored for torch.compile.

    Returns the value stored under `IGNORE_COMPILE_KEY` on `cls`, or False
    when the attribute is not set.
    """
    return getattr(cls, IGNORE_COMPILE_KEY, False)
@overload
def support_torch_compile(
*,
......@@ -148,6 +177,8 @@ def _support_torch_compile(
old_init = cls.__init__
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
......@@ -156,9 +187,11 @@ def _support_torch_compile(
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
if self.do_not_compile:
return
compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
......@@ -11,6 +11,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
......@@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
......
......@@ -7,7 +7,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
# from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
......@@ -62,7 +62,8 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):
......
......@@ -6,13 +6,7 @@ import time
import torch
from torch._dynamo.utils import lazy_format_graph_code
from vllm.config import PassConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .inductor_pass import InductorPass
......@@ -34,22 +28,9 @@ class VllmInductorPass(InductorPass):
else None
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module)
if stage in self.pass_config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("%s printing graph to %s", self.pass_name, filepath)
with open(filepath, "w") as f:
src = graph.python_code(root_module="self", verbose=True).src
# Add imports so it's not full of errors
print("import torch; from torch import device", file=f)
print(src, file=f)
def begin(self):
self._start_time = time.perf_counter_ns()
......@@ -61,10 +42,9 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: PassConfig, always=False):
def __init__(self, name: str, config: VllmConfig):
super().__init__(config)
self.name = name
self.always = always
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name, always=self.always)
self.dump_graph(graph, self.name)
......@@ -93,9 +93,10 @@ class TorchCompileWrapperWithCustomDispatcher:
return
self.compiled_codes.append(new_code)
local_cache_dir = self.vllm_config.compilation_config.local_cache_dir
if isinstance(local_cache_dir, str):
decompiled_file = os.path.join(local_cache_dir,
debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
rank = self.vllm_config.parallel_config.rank
decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
"transformed_code.py")
if not os.path.exists(decompiled_file):
try:
......@@ -105,6 +106,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# not a reversible process.
import depyf
src = depyf.decompile(new_code)
with open(decompiled_file, "w") as f:
f.write(src)
......
......@@ -17,9 +17,9 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, List,
Protocol, TypeVar, Union, cast, get_args, get_origin)
Protocol, TypeVar, Union, cast, get_args)
import regex as re
import torch
......@@ -28,12 +28,13 @@ from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self, deprecated, runtime_checkable
from typing_extensions import Self, runtime_checkable
import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
......@@ -74,6 +75,7 @@ if TYPE_CHECKING:
ConfigType = type[DataclassInstance]
HfOverrides = Union[dict, Callable[[type], type]]
else:
DataclassInstance = Any
PlacementGroup = Any
PretrainedConfig = Any
ExecutorBase = Any
......@@ -93,27 +95,23 @@ logger = init_logger(__name__)
models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH")
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
"score", "reward", "transcription", "draft"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
"transcription"]
_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
"classify", "reward", "draft"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
RunnerOption = Literal["auto", "generate", "pooling", "draft"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
RunnerType = Literal["generate", "pooling", "draft"]
_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items()
for task in tasks
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate", "transcription"],
"pooling": ["encode", "embed", "classify", "reward"],
"draft": [],
}
......@@ -228,22 +226,27 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "cpm", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
"""Configuration for the model."""
model: str = os.path.join(models_path_prefix, 'facebook/opt-125m') if models_path_prefix is not None else 'facebook/opt-125m'
model: str = os.path.join(models_path_prefix, "Qwen/Qwen3-0.6B") if models_path_prefix is not None else "Qwen/Qwen3-0.6B"
"""Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is
not specified."""
task: Literal[TaskOption, Literal["draft"]] = "auto"
"""The task to use the model for. Each vLLM instance only supports one
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
runner: RunnerOption = "auto"
"""The type of model runner to use. Each vLLM instance only supports one
model runner, even if the same model can be used for multiple types."""
task: TaskOption = "auto"
"""The task to use the model for. If the model supports more than one
model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner."""
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
......@@ -322,6 +325,13 @@ class ModelConfig:
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for the
OpenAI Chat Completions API."""
logprobs_mode: LogprobsMode = "raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
"""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
......@@ -351,9 +361,12 @@ class ModelConfig:
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
"""Maximum number of data items per modality per prompt. Only applicable
for multimodal models."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string. Defaults to False."""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
use_async_output_proc: bool = True
"""Whether to use async output processor."""
......@@ -541,16 +554,12 @@ class ModelConfig:
self.config_format = ConfigFormat(self.config_format)
hf_config = get_config(self.hf_config_path or self.model,
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)
if hf_overrides_kw:
logger.debug("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.debug("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)
self.trust_remote_code,
self.revision,
self.code_revision,
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
......@@ -560,18 +569,49 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision)
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task = task
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
# For pooling models, self.task is used to indicate the
# user-selected task
if self.task == "score":
if self._is_classify_task(self.architectures):
self.task = "classify"
else:
self.task = "embed"
elif self.task == "embedding":
msg = ("The 'embedding' task has been renamed to 'embed', please "
"use the new name. The old name will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.task = "embed"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
all_supported_tasks = self._get_supported_tasks(self.task)
logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
supported_runner_types = self._get_supported_runner_types(
all_supported_tasks)
runner_type = self._resolve_runner(self.runner, self.task,
supported_runner_types,
all_supported_tasks)
logger.debug("Selected runner type: %s", runner_type)
# For pooling models, self.task is used to indicate the
# user-selected task
if runner_type == "pooling" and self.task == "auto":
selected_task = all_supported_tasks[runner_type][-1]
assert selected_task != "encode"
self.task = selected_task
self.supported_runner_types = supported_runner_types
self.runner_type = runner_type
self.supported_tasks = all_supported_tasks[runner_type]
if self.runner_type in ("draft",
"generate") and self.task != "transcription":
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
......@@ -623,6 +663,8 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
self.model_supports_multimodal_raw_input = (
self.registry.supports_multimodal_raw_input(self.architectures))
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
......@@ -655,6 +697,16 @@ class ModelConfig:
"max_model_len must be an integer after __post_init__.")
return self
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal
return "TransformersForMultimodalLM"
else:
return "TransformersForCausalLM"
@property
def registry(self):
return me_models.ModelRegistry
......@@ -662,7 +714,19 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
# architectures in the model config.
return getattr(self.hf_config, "architectures", [])
architectures = getattr(self.hf_config, "architectures", [])
# The registry assumes that it can always inspect the vLLM model class
# for a given architecture. This assumption breaks down for the
# Transformers backend, which may use a different class depending on
# the model type. To work around this, we add the correct Transformers
# backend class to the architectures list. We must do this here because
# we need access to the `hf_config` to determine the backend class.
transformers_backend_cls = self._get_transformers_backend_cls()
if (self.model_impl != ModelImpl.VLLM.value
and all(arch != transformers_backend_cls
for arch in architectures)):
architectures.append(transformers_backend_cls)
return architectures
@property
def architecture(self) -> str:
......@@ -693,8 +757,11 @@ class ModelConfig:
# If tokenizer is same as model, download to same directory
if model == tokenizer:
s3_model.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
s3_model.pull_files(model,
ignore_pattern=[
"*.pt", "*.safetensors", "*.bin",
"*.tensors"
])
self.tokenizer = s3_model.dir
return
......@@ -702,7 +769,8 @@ class ModelConfig:
if is_s3(tokenizer):
s3_tokenizer = S3Model()
s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model,
ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
......@@ -712,7 +780,8 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self.
disable_mm_preprocessor_cache)
disable_mm_preprocessor_cache,
interleave_mm_strings=self.interleave_mm_strings)
if self.limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for "
......@@ -723,6 +792,9 @@ class ModelConfig:
if self.disable_mm_preprocessor_cache:
raise ValueError("`disable_mm_preprocessor_cache` is only "
"supported for multimodal models.")
if self.interleave_mm_strings:
raise ValueError("`interleave_mm_strings` is only "
"supported for multimodal models.")
return None
......@@ -779,107 +851,155 @@ class ModelConfig:
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(
def _is_classify_task(self, architectures: list[str]):
for arch in architectures:
if arch.endswith("ForSequenceClassification"):
return True
return self.registry.is_cross_encoder_model(architectures)
def _get_preferred_pooling_task(
self,
architectures: list[str],
supported_tasks: set[_ResolvedTask],
) -> Optional[_ResolvedTask]:
) -> _ResolvedTask:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ForSequenceClassification", "classify"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = self.registry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
if self.architecture.endswith(suffix):
return pref_task
return None
return "embed"
def _resolve_task(
def _get_supported_generation_tasks(
self,
task_option: Literal[TaskOption, Literal["draft"]],
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
task_option: TaskOption,
) -> list[_ResolvedTask]:
registry = self.registry
architectures = self.architectures
if registry.is_transcription_only_model(architectures):
return ["transcription"]
supported_tasks = list[_ResolvedTask]()
if registry.is_text_generation_model(architectures):
supported_tasks.append("generate")
if registry.is_transcription_model(architectures):
supported_tasks.append("transcription")
return supported_tasks
def _get_supported_pooling_tasks(
self,
task_option: TaskOption,
) -> list[_ResolvedTask]:
registry = self.registry
architectures = self.architectures
runner_support: dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription": registry.is_transcription_model(architectures),
"generate": registry.is_text_generation_model(architectures),
"pooling": registry.is_pooling_model(architectures),
}
supported_runner_types_lst: list[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks = list[_ResolvedTask]()
if registry.is_pooling_model(architectures):
supported_tasks.append("encode")
supported_tasks_lst: list[_ResolvedTask] = [
task for runner_type in supported_runner_types_lst
for task in _RUNNER_TASKS[runner_type]
]
supported_tasks = set(supported_tasks_lst)
# For now, users must specify the task (other than "pooling")
# to use for pooling models
if task_option == "auto":
preferred_task = self._get_preferred_pooling_task(
architectures)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
supported_tasks.append(preferred_task)
elif task_option in _RUNNER_TASKS["pooling"]:
supported_tasks.append(cast(_ResolvedTask, task_option))
if len(supported_tasks_lst) > 1:
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task is not None:
selected_task = preferred_task
return supported_tasks
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
def _get_supported_tasks(
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
if self._is_classify_task(self.architectures):
return {"generate": [], "pooling": ["classify"], "draft": []}
else:
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
}
def _get_supported_runner_types(
self,
supported_tasks: dict[RunnerType, list[_ResolvedTask]],
) -> set[RunnerType]:
return {
runner
for runner, runner_tasks in supported_tasks.items()
if len(runner_tasks) > 0
}
def _resolve_runner(
self,
runner_option: RunnerOption,
task_option: TaskOption,
supported_runner_types: set[RunnerType],
supported_tasks: dict[RunnerType, list[_ResolvedTask]],
) -> RunnerType:
if not supported_runner_types:
raise ValueError("This model does not support any model runners!")
if runner_option != "auto":
if runner_option not in supported_runner_types:
raise ValueError(
f"This model does not support runner={runner_option!r}. "
f"Available runners: {supported_runner_types}")
return runner_option
if task_option != "auto":
for runner, runner_tasks in supported_tasks.items():
if task_option in runner_tasks:
return runner
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_runner: RunnerType = next(
runner for runner, tasks in _RUNNER_TASKS.items()
if task_option in tasks)
raise ValueError(
f"This model does not support task={task_option!r}. "
f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")
task_option = "embed"
if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
for suffix, pref_runner in suffix_to_preferred_runner:
if self.architecture.endswith(
suffix) and pref_runner in supported_runner_types:
return pref_runner
selected_task = task_option
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
return "pooling"
return supported_tasks, selected_task
raise AssertionError("This line should not be reached")
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
......@@ -893,7 +1013,7 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
]
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods,
......@@ -903,9 +1023,13 @@ class ModelConfig:
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
# Use the community standard 'quant_method'
quant_method = quant_cfg.get("quant_method", "").lower()
# Normalize library names
quant_method = quant_method.replace("compressed_tensors",
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
......@@ -920,6 +1044,8 @@ class ModelConfig:
"awq_marlin",
"ipex",
"moe_wna16",
"modelopt",
"modelopt_fp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
......@@ -1134,17 +1260,17 @@ class ModelConfig:
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size
return getattr(self.hf_text_config, "vocab_size", 0)
def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size
return getattr(self.hf_text_config, "hidden_size", 0)
@property
def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
......@@ -1333,6 +1459,17 @@ class ModelConfig:
return sum(t == 1 for t in attn_type_list[start:end])
def get_mamba_chunk_size(self) -> Optional[int]:
    """
    Returns the mamba chunk size if it exists.

    Looks up `mamba_chunk_size` on `hf_text_config` first and falls back to
    `chunk_size` when that attribute is missing or None. Returns None when
    neither attribute provides a value.
    """
    # used by e.g. Bamba, FalconH1, Granite, PLaMo2
    chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
    if chunk_size is not None:
        return chunk_size
    # used by e.g. Mamba2, NemotronH, Zamba
    return getattr(self.hf_text_config, "chunk_size", None)
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
......@@ -1441,14 +1578,6 @@ class ModelConfig:
def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC
@property
def supported_runner_types(self) -> set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
@property
def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
......@@ -1456,21 +1585,25 @@ class ModelConfig:
@property
def is_matryoshka(self) -> bool:
return (hasattr(self.hf_config, "matryoshka_dimensions")
return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
or getattr(self.hf_config, "is_matryoshka", False))
@property
def matryoshka_dimensions(self):
    """The `matryoshka_dimensions` attribute of `hf_config`, or None when
    the HF config does not define it."""
    return getattr(self.hf_config, "matryoshka_dimensions", None)
@property
def use_pad_token(self) -> bool:
    """Whether to use the pad token, read from `hf_config.use_pad_token`.

    Defaults to True when the HF config does not set the attribute:
    cross_encoder models default to using pad_token, while
    `llm as reranker` models default to not using pad_token.
    """
    return getattr(self.hf_config, "use_pad_token", True)
def get_and_verify_max_len(self, max_model_len: int):
# For pooling models, the tokenizer's `model_max_length` is often a
# reliable source for the maximum sequence length. However, for
# generative models, this can be incorrect and unduly limit the
# context window (e.g., DeepSeek-R1). Therefore, we only consider
# tokenizer_config for pooling models.
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
tokenizer_config = None
if self.runner_type == "pooling":
if (self.runner_type == "pooling" and getattr(
self.hf_config, "position_embedding_type", "") == "absolute"):
tokenizer_config = try_get_tokenizer_config(
self.tokenizer,
trust_remote_code=self.trust_remote_code,
......@@ -1488,8 +1621,8 @@ class ModelConfig:
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc", "int8"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@config
......@@ -1518,7 +1651,7 @@ class CacheConfig:
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
......@@ -1534,7 +1667,12 @@ class CacheConfig:
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads."""
- "sha256" is collision resistant but with certain overheads.
This option uses Pickle for object serialization before hashing.\n
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
......@@ -1550,6 +1688,9 @@ class CacheConfig:
checkpoint if available. Otherwise, the scales will default to 1.0."""
cpu_kvcache_space_bytes: Optional[int] = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: Optional[int] = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
......@@ -1608,7 +1749,7 @@ class CacheConfig:
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
"scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
......@@ -1647,35 +1788,6 @@ class CacheConfig:
logger.warning("Possibly too large swap space. %s", msg)
@config
@dataclass
class TokenizerPoolConfig:
    """This config is deprecated and will be removed in a future release.
    Passing these parameters will have no effect. Please remove them from your
    configurations.
    """

    pool_size: int = 0
    """This parameter is deprecated and will be removed in a future release.
    Passing this parameter will have no effect. Please remove it from your
    configurations."""

    pool_type: str = "ray"
    """This parameter is deprecated and will be removed in a future release.
    Passing this parameter will have no effect. Please remove it from your
    configurations."""

    extra_config: dict = field(default_factory=dict)
    """This parameter is deprecated and will be removed in a future release.
    Passing this parameter will have no effect. Please remove it from your
    configurations."""

    def __post_init__(self) -> None:
        """Log the deprecation notice whenever an instance is created."""
        deprecation_notice = (
            "TokenizerPoolConfig is deprecated and will be removed in a "
            "future release. Passing this parameter will have no effect. "
            "Please remove it from your configurations.")
        logger.warning_once(deprecation_notice)
class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
......@@ -1727,6 +1839,9 @@ class LoadConfig:
default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
device: Optional[str] = None
"""Device to which model weights will be loaded, default to
device_config.device"""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
......@@ -1808,8 +1923,16 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
is provided explicitly to vllm serve."""
data_parallel_hybrid_lb: bool = False
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
......@@ -1839,10 +1962,6 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
"""This parameter is deprecated and will be removed in a future release.
Please remove it from your configs"""
ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
......@@ -1857,7 +1976,7 @@ class ParallelConfig:
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference."""
only support Ray for distributed inference."""
worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
......@@ -1951,6 +2070,19 @@ class ParallelConfig:
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
@staticmethod
def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
                              kv_cache_memory: int) -> int:
    """Return the minimum `kv_cache_memory` across the ranks of `dp_group`.

    A local value of -1 is replaced by the int64 maximum before the
    reduction, so it never wins the MIN (presumably -1 means "unset";
    confirm against callers).
    """
    local_value = (torch.iinfo(torch.int64).max
                   if kv_cache_memory == -1 else kv_cache_memory)
    reduced = torch.tensor([local_value], dtype=torch.int64, device="cpu")
    # we cannot use broadcast for stateless dp group since it depends
    # on global rank
    torch.distributed.all_reduce(reduced, op=ReduceOp.MIN, group=dp_group)
    return reduced.item()
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
......@@ -2010,6 +2142,15 @@ class ParallelConfig:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
if self.tensor_parallel_size * self.data_parallel_size <= 1:
raise ValueError(
"EPLB requires tensor_parallel_size or data_parallel_size "
f"to be greater than 1, but got "
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
raise ValueError(
......@@ -2030,10 +2171,11 @@ class ParallelConfig:
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:
raise ValueError("Unable to load Ray which is "
raise ValueError("Unable to load Ray: "
f"{ray_utils.ray_import_err}. Ray is "
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
"ray`.")
backend = "ray"
elif self.data_parallel_backend == "ray":
logger.info("Using ray distributed inference because "
......@@ -2144,11 +2286,12 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
cuda_graph_sizes: list[int] = field(default_factory=lambda: [512])
"""Cuda graph capture sizes, default is 512.
1. if one value is provided, then the capture list would follow the
cuda_graph_sizes: list[int] = field(default_factory=list)
"""Cuda graph capture sizes
1. if none provided, then default set to [min(max_num_seqs * 2, 512)]
2. if one value is provided, then the capture list would follow the
pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)]
2. more than one value (e.g. 1 2 128) is provided, then the capture list
3. more than one value (e.g. 1 2 128) is provided, then the capture list
will follow the provided list."""
delay_factor: float = 0.0
......@@ -2227,6 +2370,13 @@ class SchedulerConfig:
like full attention and sliding window attention.
"""
async_scheduling: bool = False
"""EXPERIMENTAL: If set to True, perform async scheduling. This may help
reduce the CPU overheads, leading to better latency and throughput. However,
async scheduling is currently not supported with some features such as
structured outputs, speculative decoding, and pipeline parallelism.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -2313,6 +2463,17 @@ class SchedulerConfig:
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
# NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
# This avoids OOM in tight memory scenarios with small max_num_seqs,
# and prevents capture of many large graphs (>512) that would greatly
# increase startup time with limited performance benefit.
if not self.cuda_graph_sizes:
self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
if self.async_scheduling:
self.scheduler_cls = (
"vllm.v1.core.sched.async_scheduler.AsyncScheduler")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if (self.max_num_batched_tokens < self.max_model_len
......@@ -2380,7 +2541,7 @@ class SchedulerConfig:
return self.num_scheduler_steps > 1
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"]
@config
......@@ -2447,8 +2608,6 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@config
......@@ -2471,13 +2630,6 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
"""The method to use for accepting draft tokens:\n
- "rejection_sampler" maps to `RejectionSampler`.\n
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
If using `typical_acceptance_sampler`, the related configuration
`posterior_threshold` and `posterior_alpha` should be considered."""
draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
......@@ -2504,9 +2656,6 @@ class SpeculativeConfig:
will use the default version."""
# Advanced control
disable_mqa_scorer: bool = False
"""Disable the MQA scorer and fall back to batch expansion for scoring
proposals."""
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
......@@ -2519,16 +2668,6 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
# Typical acceptance sampler configuration
posterior_threshold: Optional[float] = None
"""A threshold value that sets a lower bound on the posterior probability
of a token in the target model for it to be accepted. This threshold is
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
"""
posterior_alpha: Optional[float] = None
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""
speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation.
"""
......@@ -2596,7 +2735,7 @@ class SpeculativeConfig:
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
hf_config.model_type = "glm4_moe_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
......@@ -2680,7 +2819,7 @@ class SpeculativeConfig:
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
task="draft",
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.
......@@ -2714,8 +2853,8 @@ class SpeculativeConfig:
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp", "glm4_moe_mtp"):
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
......@@ -2725,6 +2864,11 @@ class SpeculativeConfig:
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
......@@ -2787,12 +2931,6 @@ class SpeculativeConfig:
self.target_parallel_config,
self.draft_tensor_parallel_size))
if self.acceptance_method == "typical_acceptance_sampler":
if self.posterior_threshold is None:
self.posterior_threshold = 0.09
if self.posterior_alpha is None:
self.posterior_alpha = 0.3
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
......@@ -2898,30 +3036,6 @@ class SpeculativeConfig:
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)
# Validate and set draft token acceptance related settings.
if self.acceptance_method is None:
raise ValueError("acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")
if (self.acceptance_method != 'rejection_sampler'
and self.acceptance_method != 'typical_acceptance_sampler'):
raise ValueError(
"Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.acceptance_method}")
if self.acceptance_method == "typical_acceptance_sampler" and (
(self.posterior_threshold is not None
and self.posterior_threshold < 0) or
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
raise ValueError(
"Expected the posterior_threshold and posterior_alpha of "
"typical_acceptance_sampler to be > 0. "
"Instead found posterior_threshold = "
f"{self.posterior_threshold} and posterior_alpha = "
f"{self.posterior_alpha}")
if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
......@@ -2988,12 +3102,17 @@ class LoRAConfig:
(added to the base model vocabulary)."""
lora_vocab_padding_size: ClassVar[int] = current_platform\
.get_lora_vocab_padding_size()
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
default_mm_loras: Optional[dict[str, str]] = None
"""Dictionary mapping specific modalities to LoRA model paths; this field
is only applicable to multimodal models and should be leveraged when a
model always expects a LoRA to be active when a given modality is present.
Note that currently, if a request provides multiple additional
modalities, each of which have their own LoRA, we do NOT apply
default_mm_loras because we currently only support one lora adapter
per prompt. When run in offline mode, the lora IDs for n modalities
will be automatically assigned to 1-n with the names of the modalities
in alphabetic order."""
bias_enabled: bool = False
"""Enable bias for LoRA adapters."""
......@@ -3016,7 +3135,6 @@ class LoRAConfig:
factors.append(self.lora_dtype)
factors.append(self.lora_extra_vocab_size)
factors.append(self.lora_vocab_padding_size)
factors.append(self.long_lora_scaling_factors)
factors.append(self.bias_enabled)
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
......@@ -3055,64 +3173,6 @@ class LoRAConfig:
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
def verify_lora_support(self):
    """Raise ValueError if long-LoRA scaling factors are set under V1."""
    wants_long_lora = self.long_lora_scaling_factors is not None
    if wants_long_lora and envs.VLLM_USE_V1:
        raise ValueError(
            "V1 LoRA does not support long LoRA, please use V0.")
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
    """Configuration for PromptAdapters."""

    max_prompt_adapters: int = 1
    """Max number of PromptAdapters in a batch."""
    max_prompt_adapter_token: int = 0
    """Max number of PromptAdapters tokens."""
    max_cpu_prompt_adapters: Optional[int] = None
    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
    `max_prompt_adapters`."""
    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
    """Data type for PromptAdapter. If auto, will default to base model dtype.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return the MD5 hex digest of the (currently empty) list of factors
        that affect the structure of the computation graph.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        digest = hashlib.md5(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        """Validate the adapter limits and default the CPU adapter count."""
        if self.max_prompt_adapters < 1:
            raise ValueError(f"max_prompt_adapters "
                             f"({self.max_prompt_adapters}) must be >= 1.")
        if self.max_prompt_adapter_token == 0:
            raise ValueError("max_prompt_adapter_token must be set.")
        if self.max_cpu_prompt_adapters is None:
            # Fall back to the per-batch adapter limit.
            self.max_cpu_prompt_adapters = self.max_prompt_adapters

    def verify_with_model_config(self, model_config: ModelConfig):
        """Resolve `prompt_adapter_dtype` into a concrete dtype.

        "auto" takes `model_config.dtype`; any other string is looked up as
        an attribute of `torch`; non-string values are left unchanged.
        """
        requested = self.prompt_adapter_dtype
        if requested == "auto":
            self.prompt_adapter_dtype = model_config.dtype
        elif isinstance(requested, str):
            self.prompt_adapter_dtype = getattr(torch, requested)
@config
@dataclass
......@@ -3130,8 +3190,8 @@ class MultiModalConfig:
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_processor_kwargs: Optional[dict[str, object]] = None
......@@ -3150,6 +3210,11 @@ class MultiModalConfig:
If `True`, disable caching of the processed multi-modal inputs.
"""
interleave_mm_strings: bool = False
"""
Enable fully interleaved support for multimodal prompts.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -3580,7 +3645,8 @@ def get_served_model_name(model: str,
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
......@@ -3590,18 +3656,6 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""
@property
@deprecated(
    "`guided_decoding_backend` is deprecated and has been renamed to "
    "`backend`. This will be removed in v0.10.0. Please use the "
    "`backend` argument instead.")
def guided_decoding_backend(self) -> GuidedDecodingBackend:
    """Deprecated read alias that returns `self.backend`."""
    return self.backend

@guided_decoding_backend.setter
def guided_decoding_backend(self, value: GuidedDecodingBackend):
    """Deprecated write alias that stores `value` in `self.backend`."""
    self.backend = value
backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
......@@ -3644,9 +3698,6 @@ class DecodingConfig:
return hash_str
def __post_init__(self):
if ":" in self.backend:
self._extract_backend_options()
if envs.VLLM_USE_V1:
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
......@@ -3662,24 +3713,6 @@ class DecodingConfig:
raise ValueError("disable_additional_properties is only supported "
"for the guidance backend.")
@deprecated(
    "Passing guided decoding backend options inside backend in the format "
    "'backend:...' is deprecated. This will be removed in v0.10.0. Please "
    "use the dedicated arguments '--disable-fallback', "
    "'--disable-any-whitespace' and '--disable-additional-properties' "
    "instead.")
def _extract_backend_options(self):
    """Extract backend options from the backend string.

    Splits `self.backend` of the form `name:opt1,opt2` at the colon, stores
    the name back into `self.backend`, and sets the matching boolean flag
    for every recognised option.
    """
    name, raw_options = self.backend.split(":")
    self.backend = cast(GuidedDecodingBackend, name)
    requested = set(raw_options.strip().split(","))
    # Legacy option token -> dedicated flag attribute it switches on.
    flag_for_option = {
        "no-fallback": "disable_fallback",
        "disable-any-whitespace": "disable_any_whitespace",
        "no-additional-properties": "disable_additional_properties",
    }
    for option, flag in flag_for_option.items():
        if option in requested:
            setattr(self, flag, True)
DetailedTraceModules = Literal["model", "worker", "all"]
......@@ -3930,11 +3963,6 @@ class PassConfig:
don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config."""
dump_graph_stages: list[str] = field(default_factory=list)
"""List of stages for which we want to dump the graph. Each pass defines
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False
......@@ -3945,6 +3973,10 @@ class PassConfig:
"""Whether to enable sequence parallelism."""
enable_async_tp: bool = False
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_token_num: int = 1024
"""Max number of tokens to be used in flashinfer allreduce fusion."""
# TODO(luka) better pass enabling system.
......@@ -3952,12 +3984,9 @@ class PassConfig:
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Do not include dump_graph_* in the hash - they don't affect
compilation.
Any future fields that don't affect compilation should be excluded.
"""
exclude = {"dump_graph_stages", "dump_graph_dir"}
dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
return InductorPass.hash_dict(dict_)
return InductorPass.hash_dict(asdict(self))
def __post_init__(self) -> None:
if not self.enable_noop:
......@@ -4062,7 +4091,7 @@ class CompilationConfig:
- True: inductor compilation is used (custom_ops disabled by default).
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
......@@ -4318,6 +4347,7 @@ class CompilationConfig:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
......@@ -4350,8 +4380,6 @@ class VllmConfig:
"""Decoding configuration."""
observability_config: Optional[ObservabilityConfig] = None
"""Observability configuration."""
prompt_adapter_config: Optional[PromptAdapterConfig] = None
"""Prompt adapter configuration."""
quant_config: Optional[QuantizationConfig] = None
"""Quantization configuration."""
compilation_config: CompilationConfig = field(
......@@ -4360,7 +4388,7 @@ class VllmConfig:
As a shorthand, `-O<n>` can be used to directly specify the compilation
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
Currently, -O <n> and -O=<n> are supported as well but this will likely be
Currently, -O <n> and -O=<n> are supported as well but this will likely be
removed in favor of clearer -O<n> syntax in the future.
NOTE: level 0 is the default level without any optimization. level 1 and 2
......@@ -4448,10 +4476,6 @@ class VllmConfig:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
......@@ -4559,10 +4583,6 @@ class VllmConfig:
if self.lora_config is not None:
self.lora_config.verify_with_cache_config(self.cache_config)
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_lora_support()
if self.prompt_adapter_config is not None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
......@@ -4670,6 +4690,13 @@ class VllmConfig:
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.model_config is not None and \
self.model_config.attention_chunk_size is not None and \
self.speculative_config is not None and \
self.speculative_config.use_eagle():
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
......@@ -4722,7 +4749,6 @@ class VllmConfig:
# calculate the default `batch_size_capture_list`
if not envs.VLLM_USE_V1:
batch_size_capture_list = []
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:
......@@ -4791,11 +4817,15 @@ class VllmConfig:
if architecture is None:
return
from vllm.model_executor.models.config import MODELS_CONFIG_MAP
from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)
if self.model_config.is_hybrid:
HybridAttentionMambaModelConfig.verify_and_update_config(self)
if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
from vllm.model_executor.models.adapters import (
......@@ -4944,3 +4974,52 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}
@config
@dataclass
class SpeechToTextConfig:
    """Configuration for speech-to-text models."""

    sample_rate: float = 16_000
    """Sample rate (Hz) to resample input audio to. Most speech models expect
    16kHz audio input. The input audio will be automatically resampled to this
    rate before processing."""

    max_audio_clip_s: int = 30
    """Maximum duration in seconds for a single audio clip without chunking.
    Audio longer than this will be split into smaller chunks if
    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""

    overlap_chunk_second: int = 1
    """Overlap duration in seconds between consecutive audio chunks when
    splitting long audio. This helps maintain context across chunk boundaries
    and improves transcription quality at split points."""

    min_energy_split_window_size: Optional[int] = 1600
    """Window size in samples for finding low-energy (quiet) regions to split
    audio chunks. The algorithm looks for the quietest moment within this
    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
    at 16kHz. If None, no chunking will be done."""

    @property
    def allow_audio_chunking(self) -> bool:
        """True when `min_energy_split_window_size` is set (not None)."""
        window = self.min_energy_split_window_size
        return window is not None
def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    """Return a copy of `config` with `overrides` applied.

    Each key must name an existing attribute of `config`. When the current
    value of that attribute is a dataclass and the override is not, the
    override must be a dict and is applied recursively to the nested
    dataclass. The copy is built with `dataclasses.replace`; `config` itself
    is not modified.

    Raises:
        AssertionError: if a key is not an attribute of `config`, or a
            nested dataclass field is overridden with something that is
            neither a dict nor a dataclass.
    """

    def _resolve(field_name: str, value: Any) -> Any:
        # Validate one override and expand dict overrides of nested
        # dataclasses.
        assert hasattr(
            config, field_name), f"{type(config)} has no field `{field_name}`"
        current_value = getattr(config, field_name)
        if not is_dataclass(current_value) or is_dataclass(value):
            return value
        assert isinstance(value, dict), (
            f"Overrides to {type(config)}.{field_name} must be a dict"
            f"  or {type(current_value)}, but got {type(value)}")
        return update_config(
            current_value,  # type: ignore[type-var]
            value)

    processed_overrides = {
        field_name: _resolve(field_name, value)
        for field_name, value in overrides.items()
    }
    return replace(config, **processed_overrides)
......@@ -7,7 +7,6 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.platforms import current_platform
from vllm.utils import Device
......@@ -56,8 +55,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
- The block IDs are assigned contiguously, with GPU block IDs coming
before CPU block IDs.
"""
# For HPU, block id 0 is used only for padding
reserved_blocks = 1 if current_platform.is_hpu() else 0
reserved_blocks = 0
block_ids = list(
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
num_gpu_blocks -= reserved_blocks
......
......@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupMetadataDelta, SequenceStage,
......@@ -165,8 +164,6 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
......@@ -194,14 +191,6 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass
class SchedulerRunningOutputs:
......@@ -1648,7 +1637,6 @@ class Scheduler:
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
from typing import Optional, Union
from weakref import WeakValueDictionary
import torch
......@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def all_gatherv(
    self,
    input_: Union[torch.Tensor, list[torch.Tensor]],
    dim: int = 0,
    sizes: Optional[list[int]] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Placeholder in the base communicator; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
......@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(self,
                    input_: torch.Tensor,
                    dim: int = -1,
                    sizes: Optional[list[int]] = None) -> torch.Tensor:
    """Placeholder in the base communicator; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
def gather(self,
input_: torch.Tensor,
dst: int = 0,
......@@ -240,8 +254,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
module.quant_method.init_prepare_finalize(module.moe_config)
def dispatch(
self, hidden_states: torch.Tensor,
......
......@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
from typing import Any, Optional, Union
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.utils import pickle
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -26,7 +27,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and unique_name.startswith("tp"):
"init_shm_manager") and (unique_name.startswith("tp")
or unique_name.startswith("pp")):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(self, input_):
......@@ -94,6 +96,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
input_size[dim + 1:])
return output_tensor
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: int,
) -> None:
    """Forward `tensor_dict` and `dst` to `self.dist_module`.

    Returns whatever `self.dist_module.send_tensor_dict` returns.
    """
    backend = self.dist_module
    return backend.send_tensor_dict(tensor_dict, dst)
def recv_tensor_dict(
    self,
    src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
    """Return the result of `self.dist_module.recv_tensor_dict(src)`."""
    backend = self.dist_module
    return backend.recv_tensor_dict(src)
class _CPUSHMDistributed:
......@@ -143,3 +158,44 @@ class _CPUSHMDistributed:
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
def send_tensor_dict(
    self,
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
    dst: int,
) -> None:
    """Send a dict of tensors to rank `dst` via `shm_send_tensor_list`.

    The tensors are sent in dict order, followed by one extra uint8 tensor
    holding the pickled `[keys, sizes]` lists.

    Raises:
        RuntimeError: if any value in `tensor_dict` is not a `torch.Tensor`.
    """
    keys = list(tensor_dict.keys())
    payload = list(tensor_dict.values())

    if any(not isinstance(item, torch.Tensor) for item in payload):
        raise RuntimeError(
            "CpuCommunicator only supports sending tensors.")
    shapes = [item.size() for item in payload]

    # Metadata travels as the last element of the tensor list.
    metadata = torch.frombuffer(pickle.dumps([keys, shapes]),
                                dtype=torch.uint8)
    payload.append(metadata)

    torch.ops._C.shm_send_tensor_list(self.handle, payload, dst)

    return None
def recv_tensor_dict(
    self,
    src: int,
) -> dict[str, Union[torch.Tensor, Any]]:
    """Receive a dict of tensors from rank ``src`` via ``shm_recv_tensor_list``.

    The last received tensor is a uint8 header holding the pickled
    ``[keys, sizes]`` pair; every preceding tensor is a value. Each value is
    viewed with its recorded size and stored under its key.

    NOTE(review): the header is decoded with ``pickle.loads``; this presumes
    the sending rank is trusted -- confirm for the deployment.
    """
    received = torch.ops._C.shm_recv_tensor_list(self.handle, src)
    payload: list[torch.Tensor] = received[:-1]
    header = received[-1]

    decoded = pickle.loads(header.numpy().tobytes())
    names = decoded[0]
    shapes = decoded[1]
    assert len(names) == len(shapes)
    assert len(names) == len(payload)

    # Restore each tensor's original shape before handing it back.
    return {
        name: tensor.view(shape)
        for name, shape, tensor in zip(names, shapes, payload)
    }
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup
......@@ -142,6 +142,42 @@ class CudaCommunicator(DeviceCommunicatorBase):
# Reshape before returning
return output.movedim(0, dim).contiguous()
def reduce_scatterv(self,
                    input_: torch.Tensor,
                    dim: int = -1,
                    sizes: Optional[list[int]] = None):
    """Reduce-scatter ``input_`` along ``dim``, optionally with uneven chunks.

    Args:
        input_: Tensor to reduce across the group.
        dim: Dimension to scatter along; negative values count from the end.
        sizes: Optional per-rank chunk lengths along the scattered
            dimension. Must have ``world_size`` entries summing to the
            length of that dimension. When ``None`` the dimension is split
            evenly and must be divisible by ``world_size``.

    Returns:
        This rank's reduced chunk, with the scattered dimension moved back
        to position ``dim`` and made contiguous.
    """
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm
    assert pynccl_comm is not None
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(0, dim).contiguous()

    if sizes is not None:
        assert len(sizes) == world_size
        assert input_tensor.shape[0] == sum(sizes)
        chunk_size = sizes[self.rank_in_group]
    else:
        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size, ) + input_tensor.shape[1:]

    output = torch.empty(output_shape,
                         dtype=input_tensor.dtype,
                         device=input_tensor.device)

    # The collective must consume the permuted, contiguous buffer that
    # `output` was sized from. Passing the raw `input_` would hand NCCL a
    # buffer whose layout does not match `output` whenever dim != 0 or
    # `input_` is non-contiguous.
    if sizes is not None:
        pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
    else:
        pynccl_comm.reduce_scatter(output, input_tensor)

    # Reshape before returning
    return output.movedim(0, dim).contiguous()
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
......@@ -180,6 +216,51 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager.destroy()
self.all2all_manager = None
def all_gatherv(self,
                input_: Union[torch.Tensor, list[torch.Tensor]],
                dim: int = 0,
                sizes: Optional[list[int]] = None):
    """All-gather one tensor or a list of tensors along dimension 0.

    Args:
        input_: A tensor, or a list of tensors gathered inside a single
            group_start/group_end bracket.
        dim: Must be 0.
        sizes: Optional per-rank lengths of dimension 0. Ignored (treated as
            ``None``) when every entry is identical.

    Returns:
        The gathered tensor, or a list of gathered tensors when ``input_``
        is a list.

    Raises:
        NotImplementedError: If ``dim`` is not 0.
    """
    if dim != 0:
        raise NotImplementedError("only dim 0 all-gatherv is supported")
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm
    assert pynccl_comm is not None and not pynccl_comm.disabled

    # 'sizes' is not needed if all inputs in the same group have the same
    # shape
    uniform = sizes is not None and all(s == sizes[0] for s in sizes)
    if uniform:
        sizes = None

    def _gather_one(tensor: torch.Tensor,
                    sizes: Optional[list[int]] = None):
        shape = tensor.size()
        if sizes is None:
            # Even split: every rank contributes shape[0] rows.
            gathered = torch.empty((shape[0] * world_size, ) + shape[1:],
                                   dtype=tensor.dtype,
                                   device=tensor.device)
            pynccl_comm.all_gather(gathered, tensor)
            return gathered

        assert len(sizes) == world_size
        assert tensor.shape[dim] == sizes[self.rank_in_group]
        gathered = torch.empty((sum(sizes), ) + shape[1:],
                               dtype=tensor.dtype,
                               device=tensor.device)
        pynccl_comm.all_gatherv(gathered, tensor, sizes=sizes)
        return gathered

    if isinstance(input_, torch.Tensor):
        return _gather_one(input_, sizes)

    # List input: issue all gathers inside one group bracket.
    results = []
    pynccl_comm.group_start()
    for tensor in input_:
        results.append(_gather_one(tensor, sizes=sizes))
    pynccl_comm.group_end()

    return results
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator(DeviceCommunicatorBase):
    """Device communicator that issues collectives over ``self.device_group``
    with an ``htorch.core.mark_step()`` call before each one."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce ``input_`` in place and return it."""
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``input_`` from every rank and concatenate along ``dim``.

        Negative ``dim`` counts from the last dimension. The result has the
        same shape as ``input_`` except that dimension ``dim`` is
        ``world_size`` times larger.
        """
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        rank_count = self.world_size
        shape = input_.size()

        # One leading slot per rank; folded into `dim` after the collective.
        stacked = torch.empty((rank_count, ) + shape,
                              dtype=input_.dtype,
                              device=input_.device)
        htorch.core.mark_step()
        dist.all_gather_into_tensor(stacked, input_, group=self.device_group)

        # Move the rank axis next to `dim`, then merge the two axes.
        stacked = stacked.movedim(0, dim)
        merged_shape = (shape[:dim] + (rank_count * shape[dim], ) +
                        shape[dim + 1:])
        return stacked.reshape(merged_shape)
......@@ -152,6 +152,40 @@ class PyNcclCommunicator:
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))
def all_gatherv(
    self,
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    sizes: list[int],
    stream=None,
):
    """All-gather with per-rank lengths, built from grouped broadcasts.

    For each index ``root`` in ``sizes`` one ``ncclBroadcast`` rooted at
    ``root`` is issued, targeting the slice of ``output_tensor`` (along
    dimension 0) that starts at the running offset and spans
    ``sizes[root]`` rows. All broadcasts are wrapped in a single
    ncclGroupStart/ncclGroupEnd pair.

    Does nothing when the communicator is disabled. ``stream`` defaults to
    ``current_stream()``.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    stream = current_stream() if stream is None else stream
    assert output_tensor.shape[0] == sum(sizes)

    begin = 0
    self.nccl.ncclGroupStart()
    for root_rank, length in enumerate(sizes):
        end = begin + length
        recv_view = output_tensor[begin:end]
        self.nccl.ncclBroadcast(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(recv_view.data_ptr()),
            recv_view.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            root_rank,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        begin = end
    self.nccl.ncclGroupEnd()
def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
......@@ -174,6 +208,38 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
def reduce_scatterv(
    self,
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    sizes: list[int],
    op: ReduceOp = ReduceOp.SUM,
    stream=None,
):
    """Reduce-scatter with per-rank lengths, built from grouped reduces.

    For each index ``root`` in ``sizes`` one ``ncclReduce`` rooted at
    ``root`` is issued. Its send buffer is the slice of ``input_tensor``
    (along dimension 0) starting at the running offset and spanning
    ``sizes[root]`` rows; its receive buffer is ``output_tensor``. All
    reduces are wrapped in a single ncclGroupStart/ncclGroupEnd pair.

    Does nothing when the communicator is disabled. ``stream`` defaults to
    ``current_stream()``.
    """
    if self.disabled:
        return
    # nccl communicator created on a specific device
    # will only work on tensors on the same device
    # otherwise it will cause "illegal memory access"
    assert input_tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {input_tensor.device}")
    stream = current_stream() if stream is None else stream

    begin = 0
    self.nccl.ncclGroupStart()
    for root_rank, length in enumerate(sizes):
        end = begin + length
        send_view = input_tensor[begin:end, ...]
        self.nccl.ncclReduce(
            buffer_type(send_view.data_ptr()),
            buffer_type(output_tensor.data_ptr()),
            send_view.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            root_rank,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )
        begin = end
    self.nccl.ncclGroupEnd()
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
......@@ -216,3 +282,9 @@ class PyNcclCommunicator:
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
def group_start(self):
    """Open an NCCL group by calling ``ncclGroupStart`` on ``self.nccl``."""
    library = self.nccl
    library.ncclGroupStart()
def group_end(self):
    """Close an NCCL group by calling ``ncclGroupEnd`` on ``self.nccl``."""
    library = self.nccl
    library.ncclGroupEnd()
......@@ -154,6 +154,17 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
......@@ -207,6 +218,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]
# class attribute to store the mapping from the path to the library
......@@ -300,6 +315,18 @@ class NCCLLibrary:
datatype, op, comm,
stream))
def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
               count: int, datatype: int, op: int, root: int,
               comm: ncclComm_t, stream: cudaStream_t) -> None:
    """Call the loaded ``ncclReduce`` symbol and check its result code.

    The arguments are forwarded positionally in the order of the C
    prototype: send buffer, receive buffer, element count, datatype,
    reduction op, root rank, communicator, stream.
    """
    # `datatype` actually should be `ncclDataType_t`
    # and `op` should be `ncclRedOp_t`
    # both are aliases of `ctypes.c_int`
    # when we pass int to a function, it will be converted to `ctypes.c_int`
    # by ctypes automatically
    check = self.NCCL_CHECK
    reduce_fn = self._funcs["ncclReduce"]
    status = reduce_fn(sendbuff, recvbuff, count, datatype, op, root, comm,
                       stream)
    check(status)
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
......@@ -342,6 +369,12 @@ class NCCLLibrary:
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
    """Call the loaded ``ncclCommDestroy`` symbol on ``comm`` and check the
    result code."""
    check = self.NCCL_CHECK
    destroy_fn = self._funcs["ncclCommDestroy"]
    check(destroy_fn(comm))
def ncclGroupStart(self) -> None:
    """Call the loaded ``ncclGroupStart`` symbol and check its result code."""
    check = self.NCCL_CHECK
    start_fn = self._funcs["ncclGroupStart"]
    check(start_fn())
def ncclGroupEnd(self) -> None:
    """Call the loaded ``ncclGroupEnd`` symbol and check its result code."""
    check = self.NCCL_CHECK
    end_fn = self._funcs["ncclGroupEnd"]
    check(end_fn())
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
......
......@@ -53,3 +53,6 @@ class XpuCommunicator(DeviceCommunicatorBase):
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
    """Run ``dist.broadcast`` on ``input_`` from ``src`` over
    ``self.device_group``."""
    group = self.device_group
    dist.broadcast(input_, src=src, group=group)
......@@ -29,12 +29,15 @@ physical experts.
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch.distributed import all_gather, all_reduce
from torch.distributed import ProcessGroup, all_gather, all_reduce
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
......@@ -172,6 +175,9 @@ class EplbState:
model: MixtureOfExperts,
device: torch.device,
parallel_config: ParallelConfig,
global_expert_load: Optional[torch.Tensor] = None,
old_global_expert_indices: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None,
) -> "EplbState":
"""
Build the initial EPLB state.
......@@ -185,8 +191,16 @@ class EplbState:
physical_to_logical_map_list,
device=device,
)
# Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now.
# TODO(rui): make this configurable
MAX_EXPERT_REDUNDANCY = 1023
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
f"num_redundant_experts {model.num_redundant_experts} "
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
logical_to_physical_map = torch.full(
(model.num_logical_experts, model.num_redundant_experts + 1),
(model.num_logical_experts, max_slots_per_logical_expert),
-1,
device=device,
)
......@@ -235,11 +249,63 @@ class EplbState:
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
expert_rearrangement_step = 0
return cls(
physical_to_logical_map,
......@@ -337,7 +403,10 @@ class EplbState:
def rearrange(self,
model: MixtureOfExperts,
is_profile: bool = False) -> None:
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None) -> None:
"""
Rearrange the experts according to the current load.
"""
......@@ -353,42 +422,79 @@ class EplbState:
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers, model.num_logical_experts,
self.physical_to_logical_map.shape[1]
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(metadata,
group=get_ep_group().cpu_group,
group_src=0)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group,
group_src=0)
return global_expert_load_window
else:
assert execute_shuffle
global_expert_load_window = global_expert_load
# TODO(bowen): Treat differently for prefill and decode nodes
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
num_gpus = sum(new_rank != -1
for new_rank in rank_mapping.values())
num_replicas = num_replicas // ep_group.size(
) * num_gpus # handle num replicas change
else:
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
self.num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
......@@ -414,10 +520,24 @@ class EplbState:
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
if self.physical_to_logical_map.shape[
1] != new_physical_to_logical_map.shape[1]:
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0,
self.logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
self.logical_replica_count.copy_(new_logical_replica_count)
......@@ -430,3 +550,69 @@ class EplbState:
" (profile) " if is_profile else " ",
time_end - time_start,
)
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
    """
    Receive the expert load and old placement from the master rank.

    Three collectives are issued, in this order, on the EP group:

    1. a CPU-group broadcast (group source 0) of three int32 values:
       number of MoE layers, number of logical experts, and number of old
       physical experts;
    2. a device-group all-reduce of an int64 load tensor of shape
       (num_moe_layers, num_logical_experts), to which this rank
       contributes zeros;
    3. a device-group broadcast (group source 0) of the int64 old
       placement of shape (num_moe_layers, num_old_physical_experts).

    Returns:
        ``(global_expert_load, old_global_expert_indices)``.
    """
    group = get_ep_group()

    # 1) Shapes first, so the receive buffers can be allocated.
    shape_info = torch.empty(3, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(shape_info,
                                group=group.cpu_group,
                                group_src=0)
    layer_count, logical_count, old_physical_count = shape_info.tolist()

    # 2) Load: zeros from this rank, summed with the other ranks' values.
    load = torch.zeros(
        (layer_count, logical_count),
        dtype=torch.int64,
        device=group.device,
    )
    all_reduce(load, group=group.device_group)

    # 3) Previous physical->logical placement.
    placement = torch.empty(
        (layer_count, old_physical_count),
        dtype=torch.int64,
        device=group.device,
    )
    torch.distributed.broadcast(placement,
                                group=group.device_group,
                                group_src=0)

    return load, placement
def _node_count_with_rank_mapping(
    pg: Union[ProcessGroup, StatelessProcessGroup],
    rank_mapping: dict[int, int],
) -> int:
    """Count the nodes that host at least one rank not mapped to -1.

    Ranks are visited in ascending order. A rank that has no node id yet
    and whose mapping is not -1 (i.e. not pending shutdown) opens a new
    node; every rank reported by ``in_the_same_node_as`` as sharing that
    node receives the same id. A group of size 1 always yields 1.
    """
    world_size = (torch.distributed.get_world_size(group=pg) if isinstance(
        pg, ProcessGroup) else pg.world_size)
    if world_size == 1:
        return 1

    # rank -> node id, where 0 means "not assigned yet".
    node_of_rank = [0] * world_size
    node_count = 0

    for rank in range(world_size):
        if node_of_rank[rank] != 0:
            continue  # Already assigned to a node

        assert rank in rank_mapping
        if rank_mapping[rank] == -1:
            continue  # Pending shutdown

        # This rank opens a new node.
        node_count += 1
        node_of_rank[rank] = node_count

        # Tag every still-unassigned rank that shares this node.
        for peer, colocated in enumerate(in_the_same_node_as(pg, rank)):
            if colocated and node_of_rank[peer] == 0:
                node_of_rank[peer] = node_count

    return node_count
......@@ -8,6 +8,7 @@ This involves the exchange of expert weights between GPUs.
from collections.abc import Iterable, MutableSequence, Sequence
from functools import partial
from typing import Optional
import torch
from torch.distributed import (P2POp, ProcessGroup, all_gather,
......@@ -127,6 +128,8 @@ def shuffle_layer(
dst_global = local2global(dst)
if is_received_locally[dst]:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights,
......@@ -139,6 +142,8 @@ def shuffle_layer(
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
continue
if expert in experts_send_loc:
continue
experts_send_loc[expert] = src
......@@ -181,6 +186,8 @@ def shuffle_layer(
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
......@@ -227,6 +234,8 @@ def shuffle_layer(
weight[dst].copy_(buffer[dst])
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src])
......@@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace(
expert_weights: Sequence[Iterable[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: Optional[dict[int, int]] = None,
) -> None:
"""
Rearranges the expert weights in place according to the new expert indices.
......@@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace(
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
rank_mapping: A dictionary mapping old rank to new rank.
"""
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = \
_map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = \
_map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[
1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
......@@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace(
)
def _map_old_expert_indices_with_rank_mapping(
    old_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
    new_ep_size: int,
) -> torch.Tensor:
    """
    Re-lay the old expert placement onto the new rank numbering.

    Args:
        old_global_expert_indices:
            Shape (num_layers, old_ep_size * num_local_physical_experts).
        rank_mapping: Mapping from old rank to new rank. Its length is
            taken as the old expert-parallel size.
        new_ep_size: New expert parallelism size.

    Returns:
        Tensor of shape
        (num_layers, new_ep_size * num_local_physical_experts) with the
        same dtype and device as the input. Slots of new ranks that no old
        rank maps to are filled with -1.
    """
    num_layers, old_slot_count = old_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    old_ep_size = len(rank_mapping)
    per_rank = old_slot_count // old_ep_size

    remapped = torch.full(
        (num_layers, new_ep_size * per_rank),
        fill_value=-1,
        dtype=old_global_expert_indices.dtype,
        device=old_global_expert_indices.device,
    )

    for old_rank in range(old_ep_size):
        new_rank = rank_mapping.get(old_rank)
        # Old ranks that are unmapped or fall outside the new group
        # contribute nothing; their experts stay -1.
        if new_rank is None or not (0 <= new_rank < new_ep_size):
            continue
        src = slice(old_rank * per_rank, (old_rank + 1) * per_rank)
        dst = slice(new_rank * per_rank, (new_rank + 1) * per_rank)
        remapped[:, dst] = old_global_expert_indices[:, src]

    return remapped
def _map_new_expert_indices_with_rank_mapping(
    new_global_expert_indices: torch.Tensor,
    rank_mapping: dict[int, int],
) -> torch.Tensor:
    """
    Re-lay the new expert placement onto the old rank numbering.

    Args:
        new_global_expert_indices:
            Shape (num_layers, new_ep_size * num_local_physical_experts),
            where new_ep_size is the number of ``rank_mapping`` values
            that are not -1.
        rank_mapping: Mapping from old rank to new rank; must contain every
            old rank in ``range(len(rank_mapping))``.

    Returns:
        Tensor of shape
        (num_layers, old_ep_size * num_local_physical_experts) with the
        same dtype and device as the input. Slots of old ranks whose new
        rank is outside ``[0, new_ep_size)`` are filled with -1.
    """
    num_layers, new_slot_count = new_global_expert_indices.shape
    assert rank_mapping, "Rank mapping is required"

    old_ep_size = len(rank_mapping)
    new_ep_size = sum(target != -1 for target in rank_mapping.values())
    per_rank = new_slot_count // new_ep_size

    remapped = torch.full(
        (num_layers, old_ep_size * per_rank),
        fill_value=-1,
        dtype=new_global_expert_indices.dtype,
        device=new_global_expert_indices.device,
    )

    for old_rank in range(old_ep_size):
        new_rank = rank_mapping[old_rank]
        if not (0 <= new_rank < new_ep_size):
            continue
        dst = slice(old_rank * per_rank, (old_rank + 1) * per_rank)
        src = slice(new_rank * per_rank, (new_rank + 1) * per_rank)
        remapped[:, dst] = new_global_expert_indices[:, src]

    return remapped
__all__ = ["rearrange_expert_weights_inplace"]
......@@ -3,12 +3,18 @@
"""
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
......@@ -107,3 +113,87 @@ def get_kv_connector_cache_layout():
"layout to HND for better xfer performance.")
return "HND"
return "NHD"
class KVOutputAggregator:
    """Combine the outputs of all workers into the output of one chosen rank.

    A request id is reported as finished sending (or receiving) only after
    it has appeared in that field of ``world_size`` worker outputs, possibly
    spread over several ``aggregate`` calls.
    """

    def __init__(self, world_size: int):
        # req_id -> number of worker reports still outstanding. An entry is
        # created on first report (starting from world_size) and deleted
        # once the count reaches zero.
        def initial_count() -> int:
            return world_size

        self._recv_remaining_count = defaultdict[str, int](initial_count)
        self._send_remaining_count = defaultdict[str, int](initial_count)

    @staticmethod
    def _count_down(req_ids: Optional[set[str]], counters: dict[str, int],
                    completed: set[str]) -> None:
        """Record one report per id; move ids with no reports left into
        ``completed``."""
        for req_id in req_ids or ():
            remaining = counters[req_id] - 1
            if remaining == 0:
                completed.add(req_id)
                del counters[req_id]
                continue
            counters[req_id] = remaining

    def aggregate(self,
                  outputs: list[ModelRunnerOutput],
                  output_rank: int = 0) -> ModelRunnerOutput:
        """Fold the finished-transfer sets of ``outputs`` into
        ``outputs[output_rank]`` and return that object.

        The returned output's ``finished_sending`` / ``finished_recving``
        are overwritten with the ids that completed on every worker, or
        ``None`` when there are none.
        """
        sent = set[str]()
        received = set[str]()
        for worker_output in outputs:
            self._count_down(worker_output.finished_sending,
                             self._send_remaining_count, sent)
            self._count_down(worker_output.finished_recving,
                             self._recv_remaining_count, received)

        # select output of the worker specified by output_rank
        chosen = outputs[output_rank]

        # Even if the chosen rank reported ids itself, they stay hidden
        # (None) until every rank has reported them.
        chosen.finished_sending = sent or None
        chosen.finished_recving = received or None

        return chosen

    def async_aggregate(self,
                        output_futures: Sequence[Future[ModelRunnerOutput]],
                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
        """Return a future that resolves to ``aggregate`` of all results.

        The returned future is cancelled if an input future was cancelled
        and receives the exception if an input future failed.
        """
        combined: Future[ModelRunnerOutput] = Future()
        collected: list[Optional[ModelRunnerOutput]] = [None] * len(
            output_futures)

        def on_done(slot: int, finished) -> None:
            if combined.done():
                return

            try:
                collected[slot] = finished.result()
            except CancelledError:
                combined.cancel()
            except Exception as e:
                combined.set_exception(e)

            # this check assumes io_thread_pool uses a single thread
            if all(collected):
                combined.set_result(
                    self.aggregate(cast(list[ModelRunnerOutput], collected),
                                   output_rank))

        for slot, pending in enumerate(output_futures):
            # Bind `slot` now; the callback only receives the future.
            pending.add_done_callback(
                lambda finished, slot=slot: on_done(slot, finished))

        return combined
......@@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVConnectorMetadata:
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
......@@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config
self._role = role
......@@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC):
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
......@@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC):
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
......@@ -190,7 +193,9 @@ class KVConnectorBase_V1(ABC):
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment