"vscode:/vscode.git/clone" did not exist on "63d7972f13d1c5a9d9bd55b664017067a9abd451"
Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -120,7 +120,7 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:
return {f"{self.modality}s": self.data}
return {f"{self.modality}s": self.get_all()}
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
......
......@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens(
self,
mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]
return {
modality: sum(
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
modality: sum(item.get_num_embeds for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
......@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def _get_mm_max_tokens(
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None:
mm_counts = self.get_mm_limits()
......@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
return self._get_mm_num_tokens(mm_inputs)
......@@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)
......
......@@ -429,12 +429,12 @@ def group_mm_kwargs_by_modality(
if merge_by_field_config is not None:
logger.warning_once(
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13."
"is deprecated and will be removed in v0.14."
)
if multimodal_cpu_fields is not None:
logger.warning_once(
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13."
"is deprecated and will be removed in v0.14."
)
from vllm.multimodal.inputs import MultiModalKwargsItems
......
......@@ -283,8 +283,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
# Allow per-request override of video backend via kwargs.
# This enables users to specify a different backend than the
# global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
# --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
video_loader_backend = (
kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
)
self.kwargs = kwargs
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
......
......@@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
......@@ -126,21 +127,13 @@ class CpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path()
......@@ -325,10 +318,16 @@ class CpuPlatform(Platform):
# We need to find the location of PyTorch's libgomp
torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg)
torch_libs = os.path.join(site_root, "torch.libs")
pytorch_libgomp_so_candidates = glob.glob(
os.path.join(torch_libs, "libgomp-*.so*")
)
# Search both torch.libs and torch/lib - See: https://github.com/vllm-project/vllm/issues/30470
torch_libs_paths = [
os.path.join(site_root, "torch.libs"),
os.path.join(torch_pkg, "lib"),
]
pytorch_libgomp_so_candidates = []
for torch_libs in torch_libs_paths:
pytorch_libgomp_so_candidates.extend(
glob.glob(os.path.join(torch_libs, "libgomp*.so*"))
)
if pytorch_libgomp_so_candidates:
pytorch_libgomp_so = pytorch_libgomp_so_candidates[0]
if ld_preload_str:
......
......@@ -7,14 +7,13 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
import torch
from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
......@@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
......@@ -182,7 +182,7 @@ class CudaPlatformBase(Platform):
if vllm_config.attention_config.backend is None:
# Default case
if cls.is_device_capability(100) and not use_sparse:
if cls.is_device_capability_family(100) and not use_sparse:
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
use_cutlass_mla = True
# Set the backend in AttentionConfig so it's used during
......@@ -255,36 +255,11 @@ class CudaPlatformBase(Platform):
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_valid_backends(
cls,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
......@@ -292,21 +267,15 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability)
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
......@@ -321,37 +290,19 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
......@@ -367,16 +318,8 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
attn_selector_config=attn_selector_config,
)
reasons_str = (
"{"
......@@ -386,11 +329,7 @@ class CudaPlatformBase(Platform):
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
......@@ -409,14 +348,50 @@ class CudaPlatformBase(Platform):
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info(
logger.info_once(
"Using %s attention backend out of potential backends: %s",
selected_backend.name,
[b[0].name for b in valid_backends_priorities],
tuple(b[0].name for b in valid_backends_priorities),
scope="local",
)
return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
......
......@@ -7,7 +7,7 @@ import platform
import random
import sys
from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
import numpy as np
import torch
......@@ -18,8 +18,8 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
......@@ -222,29 +222,52 @@ class Platform:
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
"""Get the attention backend class of a device."""
return ""
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.TORCH_SDPA,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
"""
Get the vision attention backend class of a device.
NOTE: ViT Attention should be checked and override in the platform-specific
implementation. we should not override this in any other places, like
the model_executor/models/<model_name>.py.
We check if the backend is None or not:
1. If not, check if the backend is supported by the platform.
2. If None, continue to the default selection logic.
"""
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
)
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_device_capability(
cls,
......@@ -301,6 +324,21 @@ class Platform:
return current_capability.to_int() == capability
@classmethod
def is_device_capability_family(
cls,
capability: int,
device_id: int = 0,
) -> bool:
"""
Returns True if the device capability is any <major>.x.
Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x).
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
return (current_capability.to_int() // 10) == (capability // 10)
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
......
......@@ -3,7 +3,7 @@
import os
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
......@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
logger = init_logger(__name__)
......@@ -188,42 +189,19 @@ class RocmPlatform(Platform):
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
if use_sparse:
if kv_cache_dtype.startswith("fp8"):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
......@@ -233,7 +211,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla:
if attn_selector_config.use_mla:
if selected_backend is None:
selected_backend = (
# AttentionBackendEnum.ROCM_AITER_MLA
......@@ -324,6 +302,43 @@ class RocmPlatform(Platform):
"ROCm. Note that V0 attention backends have been removed."
)
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TORCH_SDPA,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......@@ -405,7 +420,21 @@ class RocmPlatform(Platform):
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
if (
envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
# NOTE: This block has been deprecated
# or get_env_variable_attn_backend()
# == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
# TODO: monitor https://github.com/vllm-project/vllm/pull/30396
# to see how we can transition to the new way of selecting
# attention backends
):
cache_config.block_size = 64
logger.warning(
"[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
)
else:
cache_config.block_size = 16
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast
import torch
from tpu_info import device
......@@ -17,6 +17,7 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
......@@ -58,17 +59,9 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
......@@ -76,6 +69,32 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.PALLAS,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention.")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
)
return AttentionBackendEnum.PALLAS
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......
......@@ -3,7 +3,7 @@
import contextlib
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
......@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
......@@ -42,15 +43,7 @@ class XPUPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
......@@ -60,7 +53,7 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
......@@ -71,12 +64,40 @@ class XPUPlatform(Platform):
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_mla: {use_mla}"
f"with use_mla: {attn_selector_config.use_mla}"
)
logger.info("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
# XPU only supports FLASH_ATTN for vision attention.
return [
AttentionBackendEnum.FLASH_ATTN,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: "
f"{cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention")
return backend
logger.info_once(
f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
)
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
......@@ -110,12 +131,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def inference_mode(cls):
return torch.no_grad()
......
......@@ -61,7 +61,7 @@ class WorkerProfiler(ABC):
"""Call _stop with error handling but no safeguards."""
try:
self._stop()
logger.info("Profiler stopped successfully.")
logger.info_once("Profiler stopped successfully.", scope="local")
except Exception as e:
logger.warning("Failed to stop profiler: %s", e)
self._running = False # Always mark as not running, assume stop worked
......@@ -91,7 +91,7 @@ class WorkerProfiler(ABC):
and self._delay_iters > 0
and self._active_iteration_count == self._delay_iters
):
logger.info("Starting profiler after delay...")
logger.info_once("Starting profiler after delay...", scope="local")
self._call_start()
if self._running:
......@@ -105,7 +105,9 @@ class WorkerProfiler(ABC):
# Automatically stop the profiler after max iters
# will be marked as not running, but leave as active so that stop
# can clean up properly
logger.info("Max profiling iterations reached. Stopping profiler...")
logger.info_once(
"Max profiling iterations reached. Stopping profiler...", scope="local"
)
self._call_stop()
return
......@@ -125,7 +127,7 @@ class WorkerProfiler(ABC):
def shutdown(self) -> None:
"""Ensure profiler is stopped when shutting down."""
logger.info_once("Shutting down profiler")
logger.info_once("Shutting down profiler", scope="local")
if self._running:
self.stop()
......@@ -156,9 +158,10 @@ class TorchProfilerWrapper(WorkerProfiler):
self.profiler_config = profiler_config
torch_profiler_trace_dir = profiler_config.torch_profiler_dir
if local_rank in (None, 0):
logger.info(
logger.info_once(
"Torch profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
scope="local",
)
logger.debug(
"Profiler config: record_shapes=%s,"
......
......@@ -19,6 +19,10 @@ logger = init_logger(__name__)
class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for MiniMax M2 model.
MiniMax M2 models don't generate <think> start token, only </think> end
token. All content before </think> is reasoning, content after is the
actual response.
"""
@property
......@@ -31,6 +35,45 @@ class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content."""
return "</think>"
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extract reasoning content from a delta message for streaming.
MiniMax M2 models don't generate <think> start token, so we assume
all content is reasoning until we encounter the </think> end token.
"""
# Skip single end token
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id:
return None
# Check if end token has already appeared in previous tokens
# meaning we're past the reasoning phase
if self.end_token_id in previous_token_ids:
# We're past the reasoning phase, this is content
return DeltaMessage(content=delta_text)
# Check if end token is in delta tokens
if self.end_token_id in delta_token_ids:
# End token in delta, split reasoning and content
end_index = delta_text.find(self.end_token)
reasoning = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token) :]
return DeltaMessage(
reasoning=reasoning if reasoning else None,
content=content if content else None,
)
# No end token yet, all content is reasoning
return DeltaMessage(reasoning=delta_text)
class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
"""
......
......@@ -3,20 +3,29 @@
from functools import cached_property
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.tokenizers import MistralTokenizer
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers.mistral import MistralTokenizer
logger = init_logger(__name__)
class MistralReasoningParser(DeepSeekR1ReasoningParser):
class MistralReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for Mistral models.
The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning
The Mistral models uses `[THINK]`...`[/THINK]` tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
A valid reasoning trace should always start with a `[THINK]` token and end with
a `[/THINK]` token.
If `[THINK]` token is not generated, then this parser only returns content.
"""
def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
......@@ -53,3 +62,93 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
from mistral_common.tokens.tokenizers.base import SpecialTokens
return SpecialTokens.end_think
def is_reasoning_end(self, input_ids: list[int]) -> bool:
has_eot_token = False
for id in input_ids[::-1]:
if id == self.start_token_id:
# Reasoning ends only if a BOT token is found before a EOT token.
return has_eot_token
elif id == self.end_token_id:
has_eot_token = True
return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content
"""
has_bot_token = False
has_eot_token = False
bot_token_index = -1
eot_token_index = -1
# One for loop instead of multiple lookups
for i, token_id in enumerate(input_ids):
# We filter that we have multiple BOT tokens which should not
# happen for a well prompted trained model
if token_id == self.start_token_id and not has_bot_token:
has_bot_token = True
bot_token_index = i
elif token_id == self.end_token_id:
has_eot_token = True
eot_token_index = i
break
# 1. Only BOT has been outputted
if has_bot_token and not has_eot_token:
# Should be = [] if model is well prompted and trained.
return input_ids[:bot_token_index]
# 2. Neither BOT or EOT have been outputted
elif not has_bot_token and not has_eot_token:
return input_ids
# 3. Both BOT and EOT have been outputted.
elif has_bot_token and has_eot_token:
return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :]
# 4. Only EOT has been outputted => this should not have occured for a model
# well prompted and trained.
else:
return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]
def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
) -> tuple[str | None, str | None]:
"""
Extract reasoning content from the model output.
"""
if not model_output:
return (None, "")
# Check if the start token is present in the model output, remove it
# if it is present.
prev_bot_token, bot_token, post_bot_token = model_output.partition(
self.start_token
)
has_bot_token = bool(bot_token)
# Valid EOT tokens should follow BOT token
has_valid_eot_token = has_bot_token and self.end_token in post_bot_token
# 1. If there is BOT token followed by EOT token
if has_bot_token and has_valid_eot_token:
prev_eot_token, _, post_eot_token = post_bot_token.partition(self.end_token)
# If model is well prompted and trained prev_bot_token should be ""
content = prev_bot_token + post_eot_token
return prev_eot_token, content if content else None
# 2. Only BOT token
elif has_bot_token:
# If model is well prompted and trained prev_bot_token should be ""
return post_bot_token, prev_bot_token if prev_bot_token else None
# 3. EOT token has been outputted without BOT or neither has been outputted
else:
has_non_valid_eot_token = self.end_token in prev_bot_token
# 3.a EOT token has been outputted without BOT
# If model is well prompted and trained `has_non_valid_eot_token` should
# be `False` and the parser outputs all tokens as 'content'
if has_non_valid_eot_token:
prev_eot_token, _, post_eot_token = prev_bot_token.partition(
self.end_token
)
return None, prev_eot_token + post_eot_token
# 3.b neither BOT or EOT have been outputted
else:
return None, prev_bot_token
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .deepseekv32 import DeepseekV32Tokenizer
from .hf import HfTokenizer
from .mistral import MistralTokenizer
from .protocol import TokenizerLike
from .registry import (
TokenizerRegistry,
......@@ -15,12 +12,9 @@ from .registry import (
__all__ = [
"TokenizerLike",
"HfTokenizer",
"MistralTokenizer",
"TokenizerRegistry",
"cached_get_tokenizer",
"get_tokenizer",
"cached_tokenizer_from_config",
"init_tokenizer_from_config",
"DeepseekV32Tokenizer",
]
......@@ -2,22 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import Any
from transformers import BatchEncoding
from .deepseek_v32_encoding import encode_messages
from .hf import HfTokenizer, TokenizerLike
from .registry import TokenizerRegistry
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from .deepseek_v32_encoding import encode_messages
from .hf import CachedHfTokenizer
from .protocol import TokenizerLike
@TokenizerRegistry.register("deepseek_v32")
class DeepseekV32Tokenizer(HfTokenizer):
def __init__(self, tokenizer: TokenizerLike):
self.tokenizer = tokenizer
self.name_or_path = (
tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else ""
)
class DeepseekV32Tokenizer(CachedHfTokenizer):
@classmethod
def from_pretrained(
cls,
......@@ -38,20 +34,47 @@ class DeepseekV32Tokenizer(HfTokenizer):
)
return DeepseekV32Tokenizer(tokenizer)
def apply_chat_template(self, messages, tools=None, **kwargs):
def __init__(self, tokenizer: TokenizerLike) -> None:
super().__init__()
self.tokenizer = tokenizer
self.name_or_path = getattr(tokenizer, "name_or_path", "")
self._added_vocab = self.tokenizer.get_added_vocab()
self._added_vocab_size = len(self._added_vocab)
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> str | list[int]:
thinking = kwargs.get("thinking", False)
thinking_mode = "thinking"
if not thinking:
thinking_mode = "chat"
conversation = kwargs.get("conversation", messages)
messages = conversation.copy()
drop_thinking = True
if tools is not None and len(tools) > 0:
messages.insert(0, {"role": "system"})
messages[0]["tools"] = tools
drop_thinking = False
messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key]
# Historical reasoning content is dropped when a new user message is introduced
drop_thinking = messages[-1]["role"] == "user"
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
prompt_str = encode_messages(messages, **encode_config) # type: ignore
if kwargs.get("tokenize", True):
tokenizer_kwargs = {
k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
}
return self.encode(
prompt_str,
add_special_tokens=False,
**tokenizer_kwargs,
)
return prompt_str
def num_special_tokens_to_add(self) -> int:
......@@ -98,7 +121,7 @@ class DeepseekV32Tokenizer(HfTokenizer):
def __len__(self) -> int:
# </think> is an added token in DeepseekV32 tokenizer
return self.vocab_size + len(self.get_added_vocab())
return self.vocab_size + self._added_vocab_size
def __call__(
self,
......@@ -120,7 +143,7 @@ class DeepseekV32Tokenizer(HfTokenizer):
return self.tokenizer.get_vocab()
def get_added_vocab(self) -> dict[str, int]:
return self.tokenizer.get_added_vocab()
return self._added_vocab.copy()
def encode(
self,
......
......@@ -3,22 +3,18 @@
import contextlib
import copy
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TypeAlias
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from .protocol import TokenizerLike
from .registry import TokenizerRegistry
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
def get_cached_tokenizer(
tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
) -> TokenizerLike:
def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
"""
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown.
......@@ -65,11 +61,10 @@ def get_cached_tokenizer(
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
cached_tokenizer.__class__ = CachedTokenizer
return cached_tokenizer # type: ignore
return cached_tokenizer
@TokenizerRegistry.register("hf")
class HfTokenizer(TokenizerLike):
class CachedHfTokenizer(TokenizerLike):
@classmethod
def from_pretrained(
cls,
......@@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike):
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> "TokenizerLike":
) -> HfTokenizer:
try:
tokenizer = AutoTokenizer.from_pretrained(
path_or_repo_id,
......
......@@ -3,10 +3,11 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.logger import init_logger
from .protocol import TokenizerLike
from .registry import TokenizerRegistry
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import (
......@@ -15,9 +16,6 @@ if TYPE_CHECKING:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers import BatchEncoding
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
......@@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
return tokenizer.unk_id
@TokenizerRegistry.register("mistral")
class MistralTokenizer(TokenizerLike):
@classmethod
def from_pretrained(
......
......@@ -97,7 +97,7 @@ class TokenizerLike(Protocol):
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
) -> str | list[int]:
raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, overload
from typing import TYPE_CHECKING
import huggingface_hub
from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never, deprecated
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config.model import ModelConfig, RunnerType
logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[TokenizerLike])
_VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
}
class TokenizerRegistry:
# Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class)
REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {}
# In-tree tokenizers
@staticmethod
@overload
def register(tokenizer_mode: str) -> Callable[[_T], _T]: ...
@dataclass
class _TokenizerRegistry:
# Tokenizer mode -> (tokenizer module, tokenizer class)
tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)
# OOT tokenizers
@staticmethod
@overload
def register(tokenizer_mode: str, module: str, class_name: str) -> None: ...
@staticmethod
def register(
tokenizer_mode: str,
module: str | None = None,
class_name: str | None = None,
) -> Callable[[_T], _T] | None:
# In-tree tokenizers
if module is None or class_name is None:
def wrapper(tokenizer_cls: _T) -> _T:
assert tokenizer_mode not in TokenizerRegistry.REGISTRY
TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls
return tokenizer_cls
return wrapper
# OOT tokenizers
if tokenizer_mode in TokenizerRegistry.REGISTRY:
def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
if tokenizer_mode in self.tokenizers:
logger.warning(
"%s.%s is already registered for tokenizer_mode=%r. "
"It is overwritten by the new one.",
......@@ -72,36 +51,42 @@ class TokenizerRegistry:
tokenizer_mode,
)
TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name)
self.tokenizers[tokenizer_mode] = (module, class_name)
return None
@staticmethod
def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike":
if tokenizer_mode not in TokenizerRegistry.REGISTRY:
def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
if tokenizer_mode not in self.tokenizers:
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
item = TokenizerRegistry.REGISTRY[tokenizer_mode]
if isinstance(item, type):
return item.from_pretrained(*args, **kwargs)
module, class_name = item
module, class_name = self.tokenizers[tokenizer_mode]
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
class_ = resolve_obj_by_qualname(f"{module}.{class_name}")
return class_.from_pretrained(*args, **kwargs)
return resolve_obj_by_qualname(f"{module}.{class_name}")
def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
return tokenizer_cls.from_pretrained(*args, **kwargs)
def get_tokenizer(
TokenizerRegistry = _TokenizerRegistry(
{
mode: (f"vllm.tokenizers.{mod_relname}", cls_name)
for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items()
}
)
def resolve_tokenizer_args(
tokenizer_name: str | Path,
*args,
runner_type: "RunnerType" = "generate",
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> TokenizerLike:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
):
revision: str | None = kwargs.get("revision")
download_dir: str | None = kwargs.get("download_dir")
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
......@@ -125,16 +110,6 @@ def get_tokenizer(
)
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
tokenizer_mode = "hf"
kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models
if is_gguf(tokenizer_name):
if check_gguf_file(tokenizer_name):
......@@ -150,6 +125,21 @@ def get_tokenizer(
)
kwargs["gguf_file"] = gguf_file
if "truncation_side" not in kwargs:
if runner_type == "generate" or runner_type == "draft":
kwargs["truncation_side"] = "left"
elif runner_type == "pooling":
kwargs["truncation_side"] = "right"
else:
assert_never(runner_type)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
tokenizer_mode = "hf"
kwargs["use_fast"] = False
# Try to use official Mistral tokenizer if possible
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
......@@ -165,49 +155,70 @@ def get_tokenizer(
if tokenizer_mode == "auto":
tokenizer_mode = "hf"
tokenizer_args = (tokenizer_name, *args)
tokenizer_kwargs = dict(
return tokenizer_mode, tokenizer_name, args, kwargs
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)
def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
return cached_resolve_tokenizer_args(
config.tokenizer,
runner_type=config.runner_type,
tokenizer_mode=config.tokenizer_mode,
revision=config.tokenizer_revision,
trust_remote_code=config.trust_remote_code,
**kwargs,
)
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
def get_tokenizer(
tokenizer_name: str | Path,
*args,
tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment]
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> _T:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
download_dir=download_dir,
**kwargs,
)
if tokenizer_mode == "custom":
logger.warning_once(
"TokenizerRegistry now uses `tokenizer_mode` as the registry key "
"instead of `tokenizer_name`. "
"Please update the definition of `.from_pretrained` in "
"your custom tokenizer to accept `args=%s`, `kwargs=%s`. "
"Then, you can pass `tokenizer_mode=%r` instead of "
"`tokenizer_mode='custom'` when initializing vLLM.",
tokenizer_args,
str(tokenizer_kwargs),
tokenizer_name,
)
tokenizer_mode = str(tokenizer_name)
if tokenizer_cls == TokenizerLike:
tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
else:
tokenizer_cls_ = tokenizer_cls
tokenizer = TokenizerRegistry.get_tokenizer(
tokenizer_mode,
*tokenizer_args,
**tokenizer_kwargs,
)
tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs)
if not tokenizer.is_fast:
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
return tokenizer
return tokenizer # type: ignore
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
if model_config.skip_tokenizer_init:
return None
return cached_get_tokenizer(
model_config.tokenizer,
runner_type=model_config.runner_type,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
......@@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
)
@deprecated(
"Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14."
)
def init_tokenizer_from_config(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
return cached_tokenizer_from_config(model_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment