Commit 4eabe123 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -114,7 +114,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -114,7 +114,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
return spec_decode_worker return spec_decode_worker
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo becomes valid # If the feature combo becomes valid
class SpecDecodeWorker(LoRANotSupportedWorkerBase): class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"""Worker which implements speculative decoding. """Worker which implements speculative decoding.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.envs import VLLM_USE_MODELSCOPE from vllm import envs
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
try: try:
# Patch here, before each import happens # Patch here, before each import happens
import modelscope import modelscope
......
...@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import ( ...@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm.envs import VLLM_USE_MODELSCOPE from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -45,13 +45,12 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, ...@@ -45,13 +45,12 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
else: else:
from transformers import AutoConfig from transformers import AutoConfig
MISTRAL_CONFIG_NAME = "params.json" MISTRAL_CONFIG_NAME = "params.json"
HF_TOKEN = os.getenv('HF_TOKEN', None)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -130,7 +129,7 @@ def list_repo_files( ...@@ -130,7 +129,7 @@ def list_repo_files(
] ]
# if model is remote, use hf_hub api to list files # if model is remote, use hf_hub api to list files
try: try:
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
modelscope_list_repo_files) modelscope_list_repo_files)
return modelscope_list_repo_files(repo_id, return modelscope_list_repo_files(repo_id,
...@@ -185,7 +184,7 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, ...@@ -185,7 +184,7 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
return file_exists(str(model), return file_exists(str(model),
config_name, config_name,
revision=revision, revision=revision,
token=HF_TOKEN) token=os.getenv('HF_TOKEN', None))
def patch_rope_scaling(config: PretrainedConfig) -> None: def patch_rope_scaling(config: PretrainedConfig) -> None:
...@@ -300,7 +299,10 @@ def get_config( ...@@ -300,7 +299,10 @@ def get_config(
" - For Hugging Face models: ensure the presence of a " " - For Hugging Face models: ensure the presence of a "
"'config.json'.\n" "'config.json'.\n"
" - For Mistral models: ensure the presence of a " " - For Mistral models: ensure the presence of a "
"'params.json'.\n").format(model=model) "'params.json'.\n"
"3. For GGUF: pass the local path of the GGUF checkpoint.\n"
" Loading GGUF from a remote repo directly is not yet "
"supported.\n").format(model=model)
raise ValueError(error_message) from e raise ValueError(error_message) from e
...@@ -309,7 +311,7 @@ def get_config( ...@@ -309,7 +311,7 @@ def get_config(
model, model,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=HF_TOKEN, token=os.getenv('HF_TOKEN', None),
**kwargs, **kwargs,
) )
...@@ -321,7 +323,7 @@ def get_config( ...@@ -321,7 +323,7 @@ def get_config(
model, model,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=HF_TOKEN, token=os.getenv('HF_TOKEN', None),
**kwargs, **kwargs,
) )
else: else:
...@@ -331,7 +333,7 @@ def get_config( ...@@ -331,7 +333,7 @@ def get_config(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
token=HF_TOKEN, token=os.getenv('HF_TOKEN', None),
**kwargs, **kwargs,
) )
except ValueError as e: except ValueError as e:
...@@ -349,7 +351,7 @@ def get_config( ...@@ -349,7 +351,7 @@ def get_config(
raise e raise e
elif config_format == ConfigFormat.MISTRAL: elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=HF_TOKEN, **kwargs) config = load_params_config(model, revision, **kwargs)
else: else:
supported_formats = [ supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
...@@ -558,7 +560,7 @@ def get_sentence_transformer_tokenizer_config(model: str, ...@@ -558,7 +560,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
# If model is on HuggingfaceHub, get the repo files # If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model, repo_files = list_repo_files(model,
revision=revision, revision=revision,
token=HF_TOKEN) token=os.getenv('HF_TOKEN', None))
except Exception: except Exception:
repo_files = [] repo_files = []
...@@ -765,7 +767,7 @@ def get_hf_image_processor_config( ...@@ -765,7 +767,7 @@ def get_hf_image_processor_config(
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> dict[str, Any]:
# ModelScope does not provide an interface for image_processor # ModelScope does not provide an interface for image_processor
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
return dict() return dict()
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
if check_gguf_file(model): if check_gguf_file(model):
......
...@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig): ...@@ -52,13 +52,15 @@ class EAGLEConfig(PretrainedConfig):
assert self.model is not None, \ assert self.model is not None, \
"model should not be None when method is eagle" "model should not be None when method is eagle"
kwargs["architectures"] = [ kwargs["architectures"] = [
f"Eagle{arch}" for arch in self.model.architectures f"Eagle{arch}" if not arch.startswith("Eagle") \
else arch for arch in self.model.architectures
] ]
elif method == "eagle3": elif method == "eagle3":
assert self.model is not None, \ assert self.model is not None, \
"model should not be None when method is eagle3" "model should not be None when method is eagle3"
kwargs["architectures"] = [ kwargs["architectures"] = [
f"Eagle3{arch}" for arch in self.model.architectures f"Eagle3{arch}" if not arch.startswith("Eagle3") \
else arch for arch in self.model.architectures
] ]
else: else:
raise ValueError(f"Invalid method {method}. \ raise ValueError(f"Invalid method {method}. \
......
...@@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, ...@@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack) Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.multimodal.image import convert_image_mode
__all__ = ['OvisProcessor'] __all__ = ['OvisProcessor']
IGNORE_ID = -100 IGNORE_ID = -100
...@@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin): ...@@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin):
# pick the partition with maximum covering_ratio and break the tie using #sub_images # pick the partition with maximum covering_ratio and break the tie using #sub_images
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
if convert_to_rgb and image.mode != 'RGB': if convert_to_rgb:
image = image.convert('RGB') image = convert_image_mode(image, 'RGB')
sides = self.get_image_size() sides = self.get_image_size()
......
...@@ -13,7 +13,7 @@ import huggingface_hub ...@@ -13,7 +13,7 @@ import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_base import (TokenizerBase, from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
...@@ -168,7 +168,7 @@ def get_tokenizer( ...@@ -168,7 +168,7 @@ def get_tokenizer(
) -> AnyTokenizer: ) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope. """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
""" """
if VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
# pylint: disable=C. # pylint: disable=C.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub import huggingface_hub
import regex as re
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request( ...@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
# #
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
for message in messages: for message in messages:
if message.get("role") == "assistant": # Remove reasoning_content as unsupported by Mistral
_ = message.pop("reasoning_content", None) # type: ignore
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content = message.get("content") content = message.get("content")
if isinstance(content, list): if isinstance(content, list):
content = "\n".join(chunk.get("text") for chunk in content) content = "\n".join(chunk.get("text") for chunk in content)
......
...@@ -19,7 +19,6 @@ import json ...@@ -19,7 +19,6 @@ import json
import multiprocessing import multiprocessing
import os import os
import pickle import pickle
import re
import signal import signal
import socket import socket
import subprocess import subprocess
...@@ -34,7 +33,8 @@ import uuid ...@@ -34,7 +33,8 @@ import uuid
import warnings import warnings
import weakref import weakref
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
ArgumentTypeError, _ArgumentGroup) ArgumentTypeError, RawDescriptionHelpFormatter,
_ArgumentGroup)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
...@@ -54,6 +54,7 @@ import cloudpickle ...@@ -54,6 +54,7 @@ import cloudpickle
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import psutil import psutil
import regex as re
import torch import torch
import torch.types import torch.types
import yaml import yaml
...@@ -77,9 +78,15 @@ if TYPE_CHECKING: ...@@ -77,9 +78,15 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
# Exception strings for non-implemented encoder/decoder scenarios # Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo becomes valid # If the feature combo becomes valid
STR_NOT_IMPL_ENC_DEC_SWA = \ STR_NOT_IMPL_ENC_DEC_SWA = \
...@@ -752,16 +759,15 @@ def get_kv_cache_torch_dtype( ...@@ -752,16 +759,15 @@ def get_kv_cache_torch_dtype(
model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
if isinstance(cache_dtype, str): if isinstance(cache_dtype, str):
if cache_dtype == "auto": if cache_dtype == "auto":
if isinstance(model_dtype, str): if isinstance(model_dtype,
str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
elif isinstance(model_dtype, torch.dtype): elif isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype torch_dtype = model_dtype
else: else:
raise ValueError(f"Invalid model dtype: {model_dtype}") raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]: elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8":
torch_dtype = torch.uint8
else: else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype): elif isinstance(cache_dtype, torch.dtype):
...@@ -998,7 +1004,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: ...@@ -998,7 +1004,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
""" """
Unlike {class}`itertools.groupby`, groups are not broken by Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data. non-contiguous data.
""" """
groups = defaultdict[_K, list[_V]](list) groups = defaultdict[_K, list[_V]](list)
...@@ -1318,7 +1324,8 @@ class StoreBoolean(Action): ...@@ -1318,7 +1324,8 @@ class StoreBoolean(Action):
"Expected 'true' or 'false'.") "Expected 'true' or 'false'.")
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter): class SortedHelpFormatter(ArgumentDefaultsHelpFormatter,
RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings.""" """SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width): def _split_lines(self, text, width):
...@@ -1919,11 +1926,11 @@ class _PlaceholderBase: ...@@ -1919,11 +1926,11 @@ class _PlaceholderBase:
Disallows downstream usage of placeholder modules. Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because We need to explicitly override each dunder method because
{meth}`__getattr__` is not called when they are accessed. [`__getattr__`][vllm.utils._PlaceholderBase.__getattr__]
is not called when they are accessed.
:::{seealso} Info:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
:::
""" """
def __getattr__(self, key: str) -> Never: def __getattr__(self, key: str) -> Never:
...@@ -2522,7 +2529,7 @@ def _maybe_force_spawn(): ...@@ -2522,7 +2529,7 @@ def _maybe_force_spawn():
logger.warning( logger.warning(
"We must use the `spawn` multiprocessing start method. " "We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/" "See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing " "troubleshooting.html#python-multiprocessing "
"for more information. Reason: %s", reason) "for more information. Reason: %s", reason)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
...@@ -2787,14 +2794,17 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): ...@@ -2787,14 +2794,17 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
# Only relevant for models using ALiBi (e.g, MPT) # Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool: def check_use_alibi(model_config: ModelConfig) -> bool:
return (getattr(model_config.hf_text_config, "alibi", False) # Falcon cfg = model_config.hf_text_config
return (getattr(cfg, "alibi", False) # Falcon
or ("BloomForCausalLM" in getattr(model_config.hf_config, or ("BloomForCausalLM" in getattr(model_config.hf_config,
"architectures", [])) # Bloom "architectures", [])) # Bloom
or getattr(model_config.hf_text_config, "position_encoding_type", or getattr(cfg, "position_encoding_type", "") ==
"") == "alibi" # codellm_1b_alibi "alibi" # codellm_1b_alibi
or or (hasattr(cfg, "attn_config") # MPT
(hasattr(model_config.hf_text_config, "attn_config") # MPT and ((isinstance(cfg.attn_config, dict)
and model_config.hf_text_config.attn_config.get("alibi", False))) and cfg.attn_config.get("alibi", False)) or
(not isinstance(cfg.attn_config, dict)
and getattr(cfg.attn_config, "alibi", False)))))
def sha256(input) -> int: def sha256(input) -> int:
......
...@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): ...@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The number of entries in the last page of each request in # The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size] # the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None paged_kv_last_page_len: Optional[torch.Tensor] = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
...@@ -75,27 +77,33 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -75,27 +77,33 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.kv_cache_spec.block_size page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
mask = (torch.arange(block_table.size(1), mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype, dtype=block_table.dtype,
device=block_table.device).unsqueeze(0) device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1)) < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask] paged_kv_indices = block_table[mask]
paged_kv_indptr = torch.cat([ paged_kv_indptr = torch.cat([
torch.zeros(1, torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32) block_table_bounds.cumsum(dim=0, dtype=torch.int32)
]) ])
paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len) page_size, paged_kv_last_page_len)
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
return ( return (
paged_kv_indices, paged_kv_indices,
paged_kv_indptr, paged_kv_indptr,
paged_kv_last_page_len, paged_kv_last_page_len,
qo_indptr,
) )
def _build_decode(self, block_table_tensor: torch.Tensor, def _build_decode(self, block_table_tensor: torch.Tensor,
...@@ -105,6 +113,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -105,6 +113,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_indices, paged_kv_indices,
paged_kv_indptr, paged_kv_indptr,
paged_last_page_len, paged_last_page_len,
qo_indptr,
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
attn_metadata = AiterMLADecodeMetadata( attn_metadata = AiterMLADecodeMetadata(
...@@ -112,7 +121,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -112,7 +121,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens=seq_lens, seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr, paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices, paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len) paged_kv_last_page_len=paged_last_page_len,
qo_indptr=qo_indptr)
return attn_metadata return attn_metadata
...@@ -137,7 +147,10 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -137,7 +147,10 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, blocksparse_params, logits_soft_cap, attn_type,
**mla_args) **mla_args)
assert (num_heads == 16 or num_heads == 128), (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value.")
unsupported_features = [ unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
] ]
...@@ -189,7 +202,18 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -189,7 +202,18 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
if self.num_heads == 16:
# AITER MLA decode kernel only supports
# max_seqlen_q=1 when using 16 heads.
max_seqlen_qo = 1
else:
# AITER MLA decode Kernel handles arbitrary
# max_seqlen_q values when using 128 heads.
assert attn_metadata.prefill is not None
max_seqlen_qo = attn_metadata.prefill.max_query_len
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.qo_indptr, max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len) attn_metadata.decode.paged_kv_last_page_len)
......
...@@ -174,6 +174,7 @@ class KVCacheManager: ...@@ -174,6 +174,7 @@ class KVCacheManager:
num_new_tokens: int, num_new_tokens: int,
num_new_computed_tokens: int = 0, num_new_computed_tokens: int = 0,
new_computed_blocks: Optional[KVCacheBlocks] = None, new_computed_blocks: Optional[KVCacheBlocks] = None,
num_draft_tokens: int = 0,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False, delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]: ) -> Optional[KVCacheBlocks]:
...@@ -273,7 +274,7 @@ class KVCacheManager: ...@@ -273,7 +274,7 @@ class KVCacheManager:
# generated (accepted) tokens. # generated (accepted) tokens.
self.single_type_manager.cache_blocks( self.single_type_manager.cache_blocks(
request, self.req_to_block_hashes[request.request_id], request, self.req_to_block_hashes[request.request_id],
num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) num_computed_tokens + num_new_tokens - num_draft_tokens)
return KVCacheBlocks(new_blocks) return KVCacheBlocks(new_blocks)
......
...@@ -227,10 +227,15 @@ class Scheduler(SchedulerInterface): ...@@ -227,10 +227,15 @@ class Scheduler(SchedulerInterface):
req_index += 1 req_index += 1
continue continue
num_draft_tokens = max(
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens, num_new_tokens,
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens) num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
...@@ -310,15 +315,16 @@ class Scheduler(SchedulerInterface): ...@@ -310,15 +315,16 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting[0] request = self.waiting[0]
num_prealloc_computed_tokens = 0
# P/D: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
if is_ready: if is_ready:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
num_prealloc_computed_tokens = (
request.num_computed_tokens)
else: else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.popleft() self.waiting.popleft()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.appendleft(request)
continue continue
...@@ -349,8 +355,9 @@ class Scheduler(SchedulerInterface): ...@@ -349,8 +355,9 @@ class Scheduler(SchedulerInterface):
load_kv_async = False load_kv_async = False
# Get already-cached tokens. # Get already-cached tokens.
if num_prealloc_computed_tokens == 0: if request.num_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \ # Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks( self.kv_cache_manager.get_computed_blocks(
request) request)
...@@ -358,23 +365,22 @@ class Scheduler(SchedulerInterface): ...@@ -358,23 +365,22 @@ class Scheduler(SchedulerInterface):
if self.connector is not None: if self.connector is not None:
num_external_computed_tokens, load_kv_async = ( num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens)) request, num_new_local_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else: else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty() new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0 num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
# Total computed tokens (allocated in prior step).
num_computed_tokens = num_prealloc_computed_tokens
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work. # KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async: if load_kv_async:
assert num_external_computed_tokens > 0 assert num_external_computed_tokens > 0
num_new_tokens = 0 num_new_tokens = 0
...@@ -405,7 +411,7 @@ class Scheduler(SchedulerInterface): ...@@ -405,7 +411,7 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_computed_tokens, num_new_tokens + num_external_computed_tokens,
num_native_computed_tokens, num_new_local_computed_tokens,
new_computed_blocks, new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens, num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async, delay_cache_blocks=load_kv_async,
...@@ -457,7 +463,9 @@ class Scheduler(SchedulerInterface): ...@@ -457,7 +463,9 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request.request_id] = (
...@@ -799,6 +807,7 @@ class Scheduler(SchedulerInterface): ...@@ -799,6 +807,7 @@ class Scheduler(SchedulerInterface):
stop_reason=request.stop_reason, stop_reason=request.stop_reason,
events=request.take_events(), events=request.take_events(),
kv_transfer_params=kv_transfer_params, kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
)) ))
else: else:
......
...@@ -107,6 +107,9 @@ class EngineCoreOutput( ...@@ -107,6 +107,9 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None kv_transfer_params: Optional[dict[str, Any]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0
@property @property
def finished(self) -> bool: def finished(self) -> bool:
return self.finish_reason is not None return self.finish_reason is not None
......
...@@ -20,6 +20,8 @@ from vllm.outputs import RequestOutput ...@@ -20,6 +20,8 @@ from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -80,6 +82,9 @@ class AsyncLLM(EngineClient): ...@@ -80,6 +82,9 @@ class AsyncLLM(EngineClient):
"AsyncLLMEngine.from_vllm_config(...) or explicitly set " "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.") "VLLM_USE_V1=0 or 1 and report this issue on Github.")
# Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value()
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.log_requests = log_requests self.log_requests = log_requests
......
...@@ -57,6 +57,10 @@ class EngineCore: ...@@ -57,6 +57,10 @@ class EngineCore:
executor_fail_callback: Optional[Callable] = None): executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling" assert vllm_config.model_config.runner_type != "pooling"
# plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins
load_general_plugins()
self.vllm_config = vllm_config self.vllm_config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s", logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config) VLLM_VERSION, vllm_config)
...@@ -336,6 +340,13 @@ class EngineCore: ...@@ -336,6 +340,13 @@ class EngineCore:
return self.model_executor.collective_rpc(method, timeout, args, return self.model_executor.collective_rpc(method, timeout, args,
kwargs) kwargs)
def save_tensorized_model(
self,
tensorizer_config,
) -> None:
self.model_executor.save_tensorized_model(
tensorizer_config=tensorizer_config, )
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
...@@ -706,7 +717,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -706,7 +717,7 @@ class DPEngineCoreProc(EngineCoreProc):
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size)) world_size))
self.local_dp_rank = local_dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0 self.current_wave = 0
...@@ -779,7 +790,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -779,7 +790,7 @@ class DPEngineCoreProc(EngineCoreProc):
local_unfinished_reqs) local_unfinished_reqs)
if not self.engines_running: if not self.engines_running:
if self.local_dp_rank == 0: if self.dp_rank == 0:
# Notify client that we are pausing the loop. # Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.", logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave) self.current_wave)
......
...@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor ...@@ -27,7 +27,10 @@ from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -64,6 +67,11 @@ class LLMEngine: ...@@ -64,6 +67,11 @@ class LLMEngine:
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.log_stats = log_stats
self.stat_logger: Optional[StatLoggerBase] = None
if self.log_stats:
self.stat_logger = PrometheusStatLogger(vllm_config)
# important: init dp group before init the engine_core # important: init dp group before init the engine_core
# In the decoupled engine case this is handled in EngineCoreProc. # In the decoupled engine case this is handled in EngineCoreProc.
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
...@@ -86,7 +94,7 @@ class LLMEngine: ...@@ -86,7 +94,7 @@ class LLMEngine:
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer, self.output_processor = OutputProcessor(self.tokenizer,
log_stats=False) log_stats=self.log_stats)
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client( self.engine_core = EngineCoreClient.make_client(
...@@ -94,7 +102,7 @@ class LLMEngine: ...@@ -94,7 +102,7 @@ class LLMEngine:
asyncio_mode=False, asyncio_mode=False,
vllm_config=vllm_config, vllm_config=vllm_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=False, # FIXME: implement log_stats=self.log_stats,
) )
if not multiprocess_mode: if not multiprocess_mode:
...@@ -223,12 +231,21 @@ class LLMEngine: ...@@ -223,12 +231,21 @@ class LLMEngine:
outputs = self.engine_core.get_output() outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs. # 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs( processed_outputs = self.output_processor.process_outputs(
outputs.outputs) outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats)
# 3) Abort any reqs that finished due to stop strings. # 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
if self.stat_logger is not None:
assert outputs.scheduler_stats is not None
self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats)
return processed_outputs.request_outputs return processed_outputs.request_outputs
def get_vllm_config(self): def get_vllm_config(self):
...@@ -260,6 +277,10 @@ class LLMEngine: ...@@ -260,6 +277,10 @@ class LLMEngine:
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
def get_metrics(self) -> list[Metric]:
assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot()
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
......
...@@ -147,6 +147,7 @@ class RequestState: ...@@ -147,6 +147,7 @@ class RequestState:
finish_reason: Optional[FinishReason], finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None], stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None, kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]: ) -> Optional[RequestOutput]:
finished = finish_reason is not None finished = finish_reason is not None
...@@ -169,7 +170,7 @@ class RequestState: ...@@ -169,7 +170,7 @@ class RequestState:
return None return None
return self._new_request_output(request_id, outputs, finished, return self._new_request_output(request_id, outputs, finished,
kv_transfer_params) kv_transfer_params, num_cached_tokens)
def _new_request_output( def _new_request_output(
self, self,
...@@ -177,6 +178,7 @@ class RequestState: ...@@ -177,6 +178,7 @@ class RequestState:
outputs: list[CompletionOutput], outputs: list[CompletionOutput],
finished: bool, finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None, kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput: ) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA: if self.output_kind == RequestOutputKind.DELTA:
...@@ -193,6 +195,7 @@ class RequestState: ...@@ -193,6 +195,7 @@ class RequestState:
outputs=outputs, outputs=outputs,
finished=finished, finished=finished,
kv_transfer_params=kv_transfer_params, kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
) )
def _new_completion_output( def _new_completion_output(
...@@ -340,7 +343,7 @@ class OutputProcessor: ...@@ -340,7 +343,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks. # 2) Detokenize the token ids into text and perform stop checks.
...@@ -356,7 +359,7 @@ class OutputProcessor: ...@@ -356,7 +359,7 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects. # 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output( if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason, new_token_ids, finish_reason, stop_reason,
kv_transfer_params): kv_transfer_params, num_cached_tokens):
if req_state.queue is not None: if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate(). # AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output) req_state.queue.put(request_output)
......
...@@ -38,7 +38,7 @@ logger = init_logger(__name__) ...@@ -38,7 +38,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
EXECUTE_MODEL_TIMEOUT_S = 40 EXECUTE_MODEL_TIMEOUT_S = 300
class MultiprocExecutor(Executor): class MultiprocExecutor(Executor):
...@@ -50,6 +50,7 @@ class MultiprocExecutor(Executor): ...@@ -50,6 +50,7 @@ class MultiprocExecutor(Executor):
self.is_failed = False self.is_failed = False
self.shutdown_event = threading.Event() self.shutdown_event = threading.Event()
self.failure_callback: Optional[FailureCallback] = None self.failure_callback: Optional[FailureCallback] = None
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
self.world_size = self.parallel_config.world_size self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tensor_parallel_size = self.parallel_config.tensor_parallel_size
...@@ -107,7 +108,6 @@ class MultiprocExecutor(Executor): ...@@ -107,7 +108,6 @@ class MultiprocExecutor(Executor):
# For pipeline parallel, we use a thread pool for asynchronous # For pipeline parallel, we use a thread pool for asynchronous
# execute_model. # execute_model.
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
if self.max_concurrent_batches > 1: if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence # Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue # from the response queue
......
...@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -200,24 +200,24 @@ class PrometheusStatLogger(StatLoggerBase):
# Counters # Counters
# #
self.counter_num_preempted_reqs = self._counter_cls( self.counter_num_preempted_reqs = self._counter_cls(
name="vllm:num_preemptions_total", name="vllm:num_preemptions",
documentation="Cumulative number of preemption from the engine.", documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_prompt_tokens = self._counter_cls( self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_generation_tokens = self._counter_cls( self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens_total", name="vllm:generation_tokens",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_request_success: dict[FinishReason, self.counter_request_success: dict[FinishReason,
prometheus_client.Counter] = {} prometheus_client.Counter] = {}
counter_request_success_base = self._counter_cls( counter_request_success_base = self._counter_cls(
name="vllm:request_success_total", name="vllm:request_success",
documentation="Count of successfully processed requests.", documentation="Count of successfully processed requests.",
labelnames=labelnames + ["finished_reason"]) labelnames=labelnames + ["finished_reason"])
for reason in FinishReason: for reason in FinishReason:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
from prometheus_client import REGISTRY
from prometheus_client import Metric as PromMetric
from prometheus_client.samples import Sample
@dataclass
class Metric:
    """Common base for the metric snapshot types in this module.

    A metric is identified by its name together with a set of
    key=value labels; one vLLM instance can therefore expose
    several metrics that share a name but differ in their labels.
    """
    # Metric name as reported by the Prometheus registry.
    name: str
    # key=value labels distinguishing this sample from same-named ones.
    labels: dict[str, str]
@dataclass
class Counter(Metric):
    """An integer counter that only ever increases."""
    # Current counter value.
    value: int
@dataclass
class Vector(Metric):
    """An ordered list of integer counters.

    Prometheus has no such type; it exists here to model exactly one
    metric, vllm:spec_decode_num_accepted_tokens_per_pos.
    """
    # One counter value per index.
    values: list[int]
@dataclass
class Gauge(Metric):
    """A numeric value that may rise or fall over time."""
    # Current gauge reading.
    value: float
@dataclass
class Histogram(Metric):
    """Observations grouped into configurable buckets.

    The buckets are held in a dictionary keyed by each bucket's
    upper limit, with the observed count for that bucket as the
    value. A '+Inf' key is always present.

    The count property is the total across all buckets and equals
    the count of the '+Inf' bucket.

    The sum property is the total of every observed value.
    """
    # Total number of observations.
    count: int
    # Sum of all observed values.
    sum: float
    # Upper limit (as a string, e.g. "0.5" or "+Inf") -> observed count.
    buckets: dict[str, int]
def get_metrics_snapshot() -> list[Metric]:
    """Return the in-memory Prometheus metrics as plain Python objects.

    Every metric collected from the Prometheus ``REGISTRY`` whose name
    starts with ``vllm:`` is converted into one Gauge, Counter, Vector
    or Histogram per distinct label set. Metrics with any other name
    prefix are skipped.

    Example:
    >>> for metric in llm.get_metrics():
    ...     if isinstance(metric, Counter):
    ...         print(f"{metric} = {metric.value}")
    ...     elif isinstance(metric, Gauge):
    ...         print(f"{metric} = {metric.value}")
    ...     elif isinstance(metric, Histogram):
    ...         print(f"{metric}")
    ...         print(f"    sum = {metric.sum}")
    ...         print(f"    count = {metric.count}")
    ...         for bucket_le, value in metric.buckets.items():
    ...             print(f"    {bucket_le} = {value}")

    Raises:
        AssertionError: if a ``vllm:`` metric has a type other than
            gauge, counter or histogram.
    """
    snapshot: list[Metric] = []

    for metric in REGISTRY.collect():
        if not metric.name.startswith("vllm:"):
            continue

        if metric.type == "gauge":
            snapshot.extend(
                Gauge(name=metric.name,
                      labels=sample.labels,
                      value=sample.value)
                for sample in _get_samples(metric))
        elif metric.type == "counter":
            # Counter values are carried by the '<name>_total' samples.
            total_samples = _get_samples(metric, "_total")
            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
                # Special case: this metric is a vector of counters.
                # For each spec decoding token position, the number of
                # accepted tokens is observed using a Counter labeled
                # with 'position'. Those per-position counters are
                # folded into a single list of integers per label set.
                snapshot.extend(
                    Vector(name=metric.name, labels=labels, values=values)
                    for labels, values in
                    _digest_num_accepted_by_pos_samples(total_samples))
            else:
                snapshot.extend(
                    Counter(name=metric.name,
                            labels=sample.labels,
                            value=int(sample.value))
                    for sample in total_samples)
        elif metric.type == "histogram":
            # A histogram exposes several '_bucket' samples whose 'le'
            # label is the bucket's upper limit, plus '_count' and
            # '_sum' samples. The buckets are regrouped into a dict
            # indexed by the 'le' value; 'le=+Inf' is the catch-all
            # bucket covering every observation.
            digested = _digest_histogram(
                _get_samples(metric, "_bucket"),
                _get_samples(metric, "_count"),
                _get_samples(metric, "_sum"),
            )
            snapshot.extend(
                Histogram(name=metric.name,
                          labels=labels,
                          buckets=buckets,
                          count=count_value,
                          sum=sum_value)
                for labels, buckets, count_value, sum_value in digested)
        else:
            raise AssertionError(f"Unknown metric type {metric.type}")

    return snapshot
def _get_samples(metric: PromMetric,
                 suffix: Optional[str] = None) -> list[Sample]:
    """Select the samples of *metric* with one exact sample name.

    Args:
        metric: object exposing ``name`` and an iterable ``samples``.
        suffix: text appended to ``metric.name`` to form the sample
            name to match; with ``None`` the bare metric name is used.

    Returns:
        The matching samples, in their original order.
    """
    wanted = metric.name if suffix is None else metric.name + suffix
    return [sample for sample in metric.samples if sample.name == wanted]
def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
    """Return a copy of *labels* without *key_to_remove*.

    The input mapping is left untouched.

    Raises:
        KeyError: if *key_to_remove* is not present in *labels*.
    """
    remaining = labels.copy()
    del remaining[key_to_remove]
    return remaining
def _digest_histogram(
    bucket_samples: list[Sample], count_samples: list[Sample],
    sum_samples: list[Sample]
) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
    """Regroup flat histogram samples into one record per label set.

    Returns a list of ``(labels, buckets, count, sum)`` tuples, where
    ``labels`` excludes the ``le`` label and ``buckets`` maps each
    ``le`` value to its integer count.

    Raises:
        AssertionError: if the bucket, count and sum samples do not
            cover exactly the same label sets.
    """
    #
    # In the case of DP, the input is an indigestible
    # per-bucket-per-engine count as a list of labelled
    # samples, along with count and sum samples
    #
    # bucket_samples (in):
    #   labels = {le: 100, idx: 0}, value = 2
    #   labels = {le: 200, idx: 0}, value = 4
    #   labels = {le: Inf, idx: 0}, value = 10
    #   labels = {le: 100, idx: 1}, value = 1
    #   labels = {le: 200, idx: 1}, value = 5
    #   labels = {le: Inf, idx: 1}, value = 7
    # count_samples (in):
    #   labels = {idx: 0}, value = 10
    #   labels = {idx: 1}, value = 7
    # sum_samples (in):
    #   labels = {idx: 0}, value = 2000
    #   labels = {idx: 1}, value = 1200
    #
    # output: [
    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
    # ]
    # Label sets are turned into frozensets so they can act as dict keys.
    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
    for sample in bucket_samples:
        upper_limit = sample.labels["le"]
        group = frozenset(_strip_label(sample.labels, "le").items())
        buckets_by_labels.setdefault(group,
                                     {})[upper_limit] = int(sample.value)

    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {
        frozenset(sample.labels.items()): int(sample.value)
        for sample in count_samples
    }
    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {
        frozenset(sample.labels.items()): sample.value
        for sample in sum_samples
    }

    # All three sample kinds must describe the same label sets.
    assert (buckets_by_labels.keys() == counts_by_labels.keys() ==
            sums_by_labels.keys())

    return [(dict(group), buckets, counts_by_labels[group],
             sums_by_labels[group])
            for group, buckets in buckets_by_labels.items()]
def _digest_num_accepted_by_pos_samples(
        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
    """Fold per-position counter samples into one vector per label set.

    Returns a list of ``(labels, values)`` tuples, where ``labels``
    excludes the ``position`` label and ``values[i]`` is the integer
    count observed at position ``i``. Every vector has the same
    length - the largest position seen across all samples plus one -
    and positions without a sample are left at 0.
    """
    #
    # In the case of DP, the input is an indigestible
    # per-position-per-engine count as a list of
    # labelled samples
    #
    # samples (in):
    #   labels = {pos: 0, idx: 0}, value = 10
    #   labels = {pos: 1, idx: 0}, value = 7
    #   labels = {pos: 2, idx: 0}, value = 2
    #   labels = {pos: 0, idx: 1}, value = 5
    #   labels = {pos: 1, idx: 1}, value = 3
    #   labels = {pos: 2, idx: 1}, value = 1
    #
    # output: [
    #   {idx: 0}, [10, 7, 2]
    #   {idx: 1}, [5, 3, 1]
    # ]
    #
    highest_position = 0
    counts_by_group: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
    for sample in samples:
        index = int(sample.labels["position"])
        if index > highest_position:
            highest_position = index
        group = frozenset(_strip_label(sample.labels, "position").items())
        counts_by_group.setdefault(group, {})[index] = int(sample.value)

    width = highest_position + 1
    digested = []
    for group, counts_by_position in counts_by_group.items():
        vector = [0] * width
        for index, count in counts_by_position.items():
            vector[index] = count
        digested.append((dict(group), vector))
    return digested
...@@ -77,6 +77,10 @@ class Request: ...@@ -77,6 +77,10 @@ class Request:
self.output_token_ids = ConstantList(self._output_token_ids) self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids) self.all_token_ids = ConstantList(self._all_token_ids)
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None: if request.mm_inputs is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment