Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -120,7 +120,7 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): ...@@ -120,7 +120,7 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
return self.data[index] return self.data[index]
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
return {f"{self.modality}s": self.data} return {f"{self.modality}s": self.get_all()}
def get_passthrough_data(self) -> Mapping[str, object]: def get_passthrough_data(self) -> Mapping[str, object]:
return {} return {}
......
...@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]): ...@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens( def _get_mm_num_tokens(
self, self,
mm_inputs: MultiModalInputs, mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
return { return {
modality: sum( modality: sum(item.get_num_embeds for item in placeholders)
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
for modality, placeholders in placeholders_by_modality.items() for modality, placeholders in placeholders_by_modality.items()
} }
...@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]): ...@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"], multi_modal_placeholders=mm_inputs["mm_placeholders"],
) )
def _get_mm_max_tokens( def get_mm_max_tokens(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None: if mm_counts is None:
mm_counts = self.get_mm_limits() mm_counts = self.get_mm_limits()
...@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]): ...@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
} }
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) return self._get_mm_num_tokens(mm_inputs)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
...@@ -164,7 +164,7 @@ class MultiModalRegistry: ...@@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits profiler.get_mm_limits() if profiler_limits is None else profiler_limits
) )
return profiler.get_mm_max_contiguous_tokens( return profiler.get_mm_max_tokens(
seq_len, seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0}, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
) )
......
...@@ -429,12 +429,12 @@ def group_mm_kwargs_by_modality( ...@@ -429,12 +429,12 @@ def group_mm_kwargs_by_modality(
if merge_by_field_config is not None: if merge_by_field_config is not None:
logger.warning_once( logger.warning_once(
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` " "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13." "is deprecated and will be removed in v0.14."
) )
if multimodal_cpu_fields is not None: if multimodal_cpu_fields is not None:
logger.warning_once( logger.warning_once(
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` " "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13." "is deprecated and will be removed in v0.14."
) )
from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalKwargsItems
......
...@@ -283,8 +283,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): ...@@ -283,8 +283,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
# They can be passed to the underlying # They can be passed to the underlying
# media loaders (e.g. custom implementations) # media loaders (e.g. custom implementations)
# for flexible control. # for flexible control.
# Allow per-request override of video backend via kwargs.
# This enables users to specify a different backend than the
# global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
# --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
video_loader_backend = (
kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
)
self.kwargs = kwargs self.kwargs = kwargs
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
......
...@@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum ...@@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
...@@ -126,21 +127,13 @@ class CpuPlatform(Platform): ...@@ -126,21 +127,13 @@ class CpuPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.") raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path() return AttentionBackendEnum.CPU_ATTN.get_path()
...@@ -325,10 +318,16 @@ class CpuPlatform(Platform): ...@@ -325,10 +318,16 @@ class CpuPlatform(Platform):
# We need to find the location of PyTorch's libgomp # We need to find the location of PyTorch's libgomp
torch_pkg = os.path.dirname(torch.__file__) torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg) site_root = os.path.dirname(torch_pkg)
torch_libs = os.path.join(site_root, "torch.libs") # Search both torch.libs and torch/lib - See: https://github.com/vllm-project/vllm/issues/30470
pytorch_libgomp_so_candidates = glob.glob( torch_libs_paths = [
os.path.join(torch_libs, "libgomp-*.so*") os.path.join(site_root, "torch.libs"),
) os.path.join(torch_pkg, "lib"),
]
pytorch_libgomp_so_candidates = []
for torch_libs in torch_libs_paths:
pytorch_libgomp_so_candidates.extend(
glob.glob(os.path.join(torch_libs, "libgomp*.so*"))
)
if pytorch_libgomp_so_candidates: if pytorch_libgomp_so_candidates:
pytorch_libgomp_so = pytorch_libgomp_so_candidates[0] pytorch_libgomp_so = pytorch_libgomp_so_candidates[0]
if ld_preload_str: if ld_preload_str:
......
...@@ -7,14 +7,13 @@ pynvml. However, it should not initialize cuda context. ...@@ -7,14 +7,13 @@ pynvml. However, it should not initialize cuda context.
import os import os
from collections.abc import Callable from collections.abc import Callable
from functools import cache, wraps from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, Optional, TypeVar
import torch import torch
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
...@@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
else: else:
...@@ -182,7 +182,7 @@ class CudaPlatformBase(Platform): ...@@ -182,7 +182,7 @@ class CudaPlatformBase(Platform):
if vllm_config.attention_config.backend is None: if vllm_config.attention_config.backend is None:
# Default case # Default case
if cls.is_device_capability(100) and not use_sparse: if cls.is_device_capability_family(100) and not use_sparse:
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
use_cutlass_mla = True use_cutlass_mla = True
# Set the backend in AttentionConfig so it's used during # Set the backend in AttentionConfig so it's used during
...@@ -255,36 +255,11 @@ class CudaPlatformBase(Platform): ...@@ -255,36 +255,11 @@ class CudaPlatformBase(Platform):
torch.cuda.reset_peak_memory_stats(device) torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device) return torch.cuda.max_memory_allocated(device)
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Try FlashAttention first
if (cc := cls.get_device_capability()) and cc.major >= 8:
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_valid_backends( def get_valid_backends(
cls, cls,
head_size, device_capability: DeviceCapability,
dtype, attn_selector_config: "AttentionSelectorConfig",
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", list[str]],
...@@ -292,21 +267,15 @@ class CudaPlatformBase(Platform): ...@@ -292,21 +267,15 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons = {} invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability) backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities): for priority, backend in enumerate(backend_priorities):
try: try:
backend_class = backend.get_class() backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration( invalid_reasons_i = backend_class.validate_configuration(
head_size, device_capability=device_capability,
dtype, **attn_selector_config._asdict(),
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
...@@ -321,37 +290,19 @@ class CudaPlatformBase(Platform): ...@@ -321,37 +290,19 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one. # First try checking just the selected backend, if there is one.
if selected_backend is not None: if selected_backend is not None:
try: try:
backend_class = selected_backend.get_class() backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration( invalid_reasons = backend_class.validate_configuration(
head_size, device_capability=device_capability,
dtype, **attn_selector_config._asdict(),
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] invalid_reasons = ["ImportError"]
...@@ -367,16 +318,8 @@ class CudaPlatformBase(Platform): ...@@ -367,16 +318,8 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid, # No selected backend or the selected backend is invalid,
# so we try finding a valid backend. # so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size, device_capability=device_capability,
dtype, attn_selector_config=attn_selector_config,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
reasons_str = ( reasons_str = (
"{" "{"
...@@ -386,11 +329,7 @@ class CudaPlatformBase(Platform): ...@@ -386,11 +329,7 @@ class CudaPlatformBase(Platform):
) )
+ "}" + "}"
) )
config_str = ( config_str = attn_selector_config.__repr__()
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
logger.debug_once( logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with " f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}." f"{config_str}. Reasons: {reasons_str}."
...@@ -409,14 +348,50 @@ class CudaPlatformBase(Platform): ...@@ -409,14 +348,50 @@ class CudaPlatformBase(Platform):
) )
selected_index = sorted_indices[0] selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0] selected_backend = valid_backends_priorities[selected_index][0]
logger.info( logger.info_once(
"Using %s attention backend out of potential backends: %s", "Using %s attention backend out of potential backends: %s",
selected_backend.name, selected_backend.name,
[b[0].name for b in valid_backends_priorities], tuple(b[0].name for b in valid_backends_priorities),
scope="local",
) )
return selected_backend.get_path() return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for ViT attention on CUDA.

    `get_vit_attn_backend` checks an explicitly requested backend against
    this list.
    """
    supported = [
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.FLASH_ATTN,
    ]
    return supported
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the attention backend used for ViT attention on CUDA.

    Args:
        head_size: Attention head size, checked against FlashAttention.
        dtype: Tensor dtype, checked against FlashAttention.
        backend: Explicitly requested backend. When given it must be one of
            `get_supported_vit_attn_backends()` and is returned unchanged.

    Returns:
        The requested backend, else FLASH_ATTN when the device capability
        major version is at least 8 and FlashAttention imports and accepts
        the head size and dtype, else TORCH_SDPA.

    Raises:
        AssertionError: If `backend` is given but not supported.
    """
    if backend is None:
        # Default selection: prefer FlashAttention, fall back to SDPA.
        capability = cls.get_device_capability()
        if capability and capability.major >= 8:
            try:
                flash_attn_cls = AttentionBackendEnum.FLASH_ATTN.get_class()
                head_size_ok = flash_attn_cls.supports_head_size(head_size)
                if head_size_ok and flash_attn_cls.supports_dtype(dtype):
                    return AttentionBackendEnum.FLASH_ATTN
            except ImportError:
                # FlashAttention is not importable; use the fallback below.
                pass
        return AttentionBackendEnum.TORCH_SDPA

    # An explicit request bypasses the default selection entirely.
    supported = cls.get_supported_vit_attn_backends()
    assert backend in supported, (
        f"Backend {backend} is not supported for vit attention. "
        f"Supported backends are: {supported}"
    )
    logger.info_once(f"Using backend {backend} for vit attention")
    return backend
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
......
...@@ -7,7 +7,7 @@ import platform ...@@ -7,7 +7,7 @@ import platform
import random import random
import sys import sys
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, NamedTuple from typing import TYPE_CHECKING, Any, NamedTuple, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -18,8 +18,8 @@ from vllm.logger import init_logger ...@@ -18,8 +18,8 @@ from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -222,29 +222,52 @@ class Platform: ...@@ -222,29 +222,52 @@ class Platform:
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for ViT attention by default.

    The base platform only lists TORCH_SDPA.
    """
    supported = [AttentionBackendEnum.TORCH_SDPA]
    return supported
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """
    Get the vision attention backend class of a device.

    NOTE: ViT attention selection should be checked and overridden in the
    platform-specific implementation. It should not be overridden anywhere
    else, such as model_executor/models/<model_name>.py.

    Args:
        head_size: Attention head size. Not used by this default
            implementation.
        dtype: Tensor dtype. Not used by this default implementation.
        backend: Explicitly requested backend. If not None, it is checked
            against `get_supported_vit_attn_backends()` and returned as is.
            If None, the default selection below applies.

    Returns:
        The requested backend, or TORCH_SDPA when none was requested.

    Raises:
        AssertionError: If `backend` is given but not supported.
    """
    if backend is not None:
        supported_backends = cls.get_supported_vit_attn_backends()
        # Adjacent f-strings are concatenated verbatim, so the first fragment
        # must end with ". " to keep the two sentences apart.
        assert backend in supported_backends, (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {supported_backends}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

    logger.info_once(
        f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention"
    )
    return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_device_capability( def get_device_capability(
cls, cls,
...@@ -301,6 +324,21 @@ class Platform: ...@@ -301,6 +324,21 @@ class Platform:
return current_capability.to_int() == capability return current_capability.to_int() == capability
@classmethod
def is_device_capability_family(
    cls,
    capability: int,
    device_id: int = 0,
) -> bool:
    """
    Check whether the device belongs to the same capability family.

    Both the device's integer capability and `capability` are floor-divided
    by 10 before comparison, so any <major>.x device matches its family
    (e.g. 10.x, 11.x, 12.x), mirroring CUDA 13 'family' architecture
    semantics.

    Returns False when the device capability is unavailable (None).
    """
    detected = cls.get_device_capability(device_id=device_id)
    if detected is not None:
        # Drop the last decimal digit on both sides and compare the rest.
        device_family = detected.to_int() // 10
        requested_family = capability // 10
        return device_family == requested_family
    return False
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device.""" """Get the name of a device."""
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import os import os
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -188,42 +189,19 @@ class RocmPlatform(Platform): ...@@ -188,42 +189,19 @@ class RocmPlatform(Platform):
if not on_gfx9(): if not on_gfx9():
supported_quantization += ["bitsandbytes"] supported_quantization += ["bitsandbytes"]
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return AttentionBackendEnum.FLASH_ATTN
return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend, selected_backend: "AttentionBackendEnum",
head_size, attn_selector_config: "AttentionSelectorConfig",
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
if use_sparse: block_size = attn_selector_config.block_size
if kv_cache_dtype.startswith("fp8"): kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError( raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
) )
...@@ -233,7 +211,7 @@ class RocmPlatform(Platform): ...@@ -233,7 +211,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Sparse MLA backend on V1 engine.") logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla: if attn_selector_config.use_mla:
if selected_backend is None: if selected_backend is None:
selected_backend = ( selected_backend = (
# AttentionBackendEnum.ROCM_AITER_MLA # AttentionBackendEnum.ROCM_AITER_MLA
...@@ -324,6 +302,43 @@ class RocmPlatform(Platform): ...@@ -324,6 +302,43 @@ class RocmPlatform(Platform):
"ROCm. Note that V0 attention backends have been removed." "ROCm. Note that V0 attention backends have been removed."
) )
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for ViT attention on ROCm.

    `get_vit_attn_backend` checks an explicitly requested backend against
    this list.
    """
    supported = [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.TORCH_SDPA,
    ]
    return supported
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the attention backend used for ViT attention on ROCm.

    Args:
        head_size: Attention head size. Not used by this implementation.
        dtype: Tensor dtype. Not used by this implementation.
        backend: Explicitly requested backend. When given it must be one of
            `get_supported_vit_attn_backends()` and is returned unchanged.

    Returns:
        The requested backend, else ROCM_AITER_FA when AITER MHA is enabled,
        else FLASH_ATTN on gfx9 when the `flash_attn` package can be found,
        else TORCH_SDPA.

    Raises:
        AssertionError: If `backend` is given but not supported.
    """
    if backend is not None:
        supported = cls.get_supported_vit_attn_backends()
        assert backend in supported, (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {supported}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

    # Imported lazily, only when the default selection is actually needed.
    from importlib.util import find_spec

    from vllm._aiter_ops import rocm_aiter_ops

    if rocm_aiter_ops.is_mha_enabled():
        # Note: AITER FA is only supported for Qwen-VL models.
        # TODO: Add support for other VL models in their model class.
        return AttentionBackendEnum.ROCM_AITER_FA

    flash_attn_usable = on_gfx9() and find_spec("flash_attn") is not None
    if flash_attn_usable:
        return AttentionBackendEnum.FLASH_ATTN
    return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
""" """
...@@ -405,7 +420,21 @@ class RocmPlatform(Platform): ...@@ -405,7 +420,21 @@ class RocmPlatform(Platform):
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 if (
envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
# NOTE: This block has been deprecated
# or get_env_variable_attn_backend()
# == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
# TODO: monitor https://github.com/vllm-project/vllm/pull/30396
# to see how we can transition to the new way of selecting
# attention backends
):
cache_config.block_size = 64
logger.warning(
"[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64."
)
else:
cache_config.block_size = 16
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, Optional, cast
import torch import torch
from tpu_info import device from tpu_info import device
...@@ -17,6 +17,7 @@ from .interface import Platform, PlatformEnum ...@@ -17,6 +17,7 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeAlias from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import BlockSize from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -58,17 +59,9 @@ class TpuPlatform(Platform): ...@@ -58,17 +59,9 @@ class TpuPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.") raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS: if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
...@@ -76,6 +69,32 @@ class TpuPlatform(Platform): ...@@ -76,6 +69,32 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.") logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path() return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for ViT attention on TPU.

    Only PALLAS is listed.
    """
    supported = [AttentionBackendEnum.PALLAS]
    return supported
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Choose the attention backend used for ViT attention on TPU.

    Args:
        head_size: Attention head size. Not used by this implementation.
        dtype: Tensor dtype. Not used by this implementation.
        backend: Explicitly requested backend. When given it must be one of
            `get_supported_vit_attn_backends()` and is returned unchanged.

    Returns:
        The requested backend, or PALLAS when none was requested.

    Raises:
        AssertionError: If `backend` is given but not supported.
    """
    if backend is not None:
        supported_backends = cls.get_supported_vit_attn_backends()
        # Adjacent f-strings are concatenated verbatim, so the first fragment
        # must end with ". " to keep the two sentences apart.
        assert backend in supported_backends, (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {supported_backends}."
        )
        logger.info_once(f"Using backend {backend} for vit attention.")
        return backend

    logger.info_once(
        f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
    )
    return AttentionBackendEnum.PALLAS
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
""" """
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import contextlib import contextlib
import os import os
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -14,6 +14,7 @@ from vllm.logger import init_logger ...@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
...@@ -42,15 +43,7 @@ class XPUPlatform(Platform): ...@@ -42,15 +43,7 @@ class XPUPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout
...@@ -60,7 +53,7 @@ class XPUPlatform(Platform): ...@@ -60,7 +53,7 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels." "only NHD layout is supported by XPU attention kernels."
) )
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN: if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.") logger.info_once("Using Triton backend.")
...@@ -71,12 +64,40 @@ class XPUPlatform(Platform): ...@@ -71,12 +64,40 @@ class XPUPlatform(Platform):
elif selected_backend: elif selected_backend:
raise ValueError( raise ValueError(
f"Invalid attention backend for {cls.device_name}, " f"Invalid attention backend for {cls.device_name}, "
f"with use_mla: {use_mla}" f"with use_mla: {attn_selector_config.use_mla}"
) )
logger.info("Using Flash Attention backend.") logger.info("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path() return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends accepted for vision (ViT) attention.

    On XPU the list contains a single entry, FLASH_ATTN.
    """
    return [AttentionBackendEnum.FLASH_ATTN]
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
    """Pick the attention backend used for vision (ViT) attention.

    Args:
        head_size: Not consulted by this implementation.
        dtype: Not consulted by this implementation.
        backend: Explicitly requested backend, or None to use the default.

    Returns:
        The requested backend when one is given, otherwise FLASH_ATTN.

    Raises:
        AssertionError: If `backend` is given but is not one of the values
            returned by `get_supported_vit_attn_backends()`.
    """
    # No explicit request: fall back to the default backend.
    if backend is None:
        logger.info_once(
            f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
        )
        return AttentionBackendEnum.FLASH_ATTN

    # An explicit request must be one of the supported backends.
    assert backend in cls.get_supported_vit_attn_backends(), (
        f"Backend {backend} is not supported for vit attention. "
        f"Supported backends are: "
        f"{cls.get_supported_vit_attn_backends()}."
    )
    logger.info_once(f"Using backend {backend} for vit attention")
    return backend
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
""" """
...@@ -110,12 +131,6 @@ class XPUPlatform(Platform): ...@@ -110,12 +131,6 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id) device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory return device_props.total_memory
@classmethod
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
return AttentionBackendEnum.FLASH_ATTN
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
......
...@@ -61,7 +61,7 @@ class WorkerProfiler(ABC): ...@@ -61,7 +61,7 @@ class WorkerProfiler(ABC):
"""Call _stop with error handling but no safeguards.""" """Call _stop with error handling but no safeguards."""
try: try:
self._stop() self._stop()
logger.info("Profiler stopped successfully.") logger.info_once("Profiler stopped successfully.", scope="local")
except Exception as e: except Exception as e:
logger.warning("Failed to stop profiler: %s", e) logger.warning("Failed to stop profiler: %s", e)
self._running = False # Always mark as not running, assume stop worked self._running = False # Always mark as not running, assume stop worked
...@@ -91,7 +91,7 @@ class WorkerProfiler(ABC): ...@@ -91,7 +91,7 @@ class WorkerProfiler(ABC):
and self._delay_iters > 0 and self._delay_iters > 0
and self._active_iteration_count == self._delay_iters and self._active_iteration_count == self._delay_iters
): ):
logger.info("Starting profiler after delay...") logger.info_once("Starting profiler after delay...", scope="local")
self._call_start() self._call_start()
if self._running: if self._running:
...@@ -105,7 +105,9 @@ class WorkerProfiler(ABC): ...@@ -105,7 +105,9 @@ class WorkerProfiler(ABC):
# Automatically stop the profiler after max iters # Automatically stop the profiler after max iters
# will be marked as not running, but leave as active so that stop # will be marked as not running, but leave as active so that stop
# can clean up properly # can clean up properly
logger.info("Max profiling iterations reached. Stopping profiler...") logger.info_once(
"Max profiling iterations reached. Stopping profiler...", scope="local"
)
self._call_stop() self._call_stop()
return return
...@@ -125,7 +127,7 @@ class WorkerProfiler(ABC): ...@@ -125,7 +127,7 @@ class WorkerProfiler(ABC):
def shutdown(self) -> None: def shutdown(self) -> None:
"""Ensure profiler is stopped when shutting down.""" """Ensure profiler is stopped when shutting down."""
logger.info_once("Shutting down profiler") logger.info_once("Shutting down profiler", scope="local")
if self._running: if self._running:
self.stop() self.stop()
...@@ -156,9 +158,10 @@ class TorchProfilerWrapper(WorkerProfiler): ...@@ -156,9 +158,10 @@ class TorchProfilerWrapper(WorkerProfiler):
self.profiler_config = profiler_config self.profiler_config = profiler_config
torch_profiler_trace_dir = profiler_config.torch_profiler_dir torch_profiler_trace_dir = profiler_config.torch_profiler_dir
if local_rank in (None, 0): if local_rank in (None, 0):
logger.info( logger.info_once(
"Torch profiling enabled. Traces will be saved to: %s", "Torch profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir, torch_profiler_trace_dir,
scope="local",
) )
logger.debug( logger.debug(
"Profiler config: record_shapes=%s," "Profiler config: record_shapes=%s,"
......
...@@ -19,6 +19,10 @@ logger = init_logger(__name__) ...@@ -19,6 +19,10 @@ logger = init_logger(__name__)
class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser): class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
""" """
Reasoning parser for MiniMax M2 model. Reasoning parser for MiniMax M2 model.
MiniMax M2 models don't generate <think> start token, only </think> end
token. All content before </think> is reasoning, content after is the
actual response.
""" """
@property @property
...@@ -31,6 +35,45 @@ class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser): ...@@ -31,6 +35,45 @@ class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content.""" """The token that ends reasoning content."""
return "</think>" return "</think>"
def extract_reasoning_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
    """
    Extract reasoning content from a delta message for streaming.

    MiniMax M2 models don't generate <think> start token, so we assume
    all content is reasoning until we encounter the </think> end token.

    Args:
        previous_text: Not consulted by this implementation.
        current_text: Not consulted by this implementation.
        delta_text: Text produced by the current step.
        previous_token_ids: Token ids produced before the current step.
        current_token_ids: Not consulted by this implementation.
        delta_token_ids: Token ids produced by the current step.

    Returns:
        None when the delta consists solely of the end token; otherwise a
        DeltaMessage carrying the delta as reasoning, as content, or split
        between the two around the end token.
    """
    # Skip single end token
    if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id:
        return None

    # The end token was already emitted in an earlier step, so everything
    # produced from now on is regular content.
    if self.end_token_id in previous_token_ids:
        return DeltaMessage(content=delta_text)

    # The end token arrives in this delta: split the text around it.
    if self.end_token_id in delta_token_ids:
        reasoning, separator, content = delta_text.partition(self.end_token)
        if not separator:
            # The end token id is present but its text is not part of
            # delta_text. Slicing around a "not found" index of -1 would
            # silently drop or misplace characters, so keep the delta
            # intact and report it as reasoning instead.
            return DeltaMessage(reasoning=delta_text if delta_text else None)
        return DeltaMessage(
            reasoning=reasoning if reasoning else None,
            content=content if content else None,
        )

    # No end token yet, all content is reasoning
    return DeltaMessage(reasoning=delta_text)
class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
""" """
......
...@@ -3,20 +3,29 @@ ...@@ -3,20 +3,29 @@
from functools import cached_property from functools import cached_property
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ResponsesRequest,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
class MistralReasoningParser(DeepSeekR1ReasoningParser): class MistralReasoningParser(BaseThinkingReasoningParser):
""" """
Reasoning parser for Mistral models. Reasoning parser for Mistral models.
The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning The Mistral models uses `[THINK]`...`[/THINK]` tokens to denote reasoning
text. This parser extracts the reasoning content from the model output. text. This parser extracts the reasoning content from the model output.
A valid reasoning trace should always start with a `[THINK]` token and end with
a `[/THINK]` token.
If `[THINK]` token is not generated, then this parser only returns content.
""" """
def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs): def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
...@@ -53,3 +62,93 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser): ...@@ -53,3 +62,93 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokens
return SpecialTokens.end_think return SpecialTokens.end_think
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Tell whether the most recent reasoning section has been closed.

    The ids are scanned from the end towards the beginning. The scan stops
    at the first start-of-think token it meets and reports whether an
    end-of-think token was seen on the way there. If no start-of-think
    token is present at all, the result is False.
    """
    seen_end_token = False
    for token_id in reversed(input_ids):
        if token_id == self.start_token_id:
            # Reasoning ends only if a BOT token is found before a EOT token.
            return seen_end_token
        if token_id == self.end_token_id:
            seen_end_token = True
    return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """
    Extract the content token ids, i.e. the ids outside the reasoning span.

    Only the first start-of-think (BOT) token and the first end-of-think
    (EOT) token are considered; the scan stops at the first EOT token.
    """
    bot_index = None
    eot_index = None

    # Single pass: remember the first BOT and stop at the first EOT.
    for position, token_id in enumerate(input_ids):
        # Additional BOT tokens are ignored; they should not happen for a
        # well prompted and trained model.
        if bot_index is None and token_id == self.start_token_id:
            bot_index = position
        elif token_id == self.end_token_id:
            eot_index = position
            break

    if eot_index is None:
        # Neither BOT nor EOT has been outputted: everything is content.
        if bot_index is None:
            return input_ids
        # Only BOT has been outputted: content is whatever precedes it
        # (should be empty if the model is well prompted and trained).
        return input_ids[:bot_index]

    # EOT has been outputted. With a BOT the reasoning span runs from BOT to
    # EOT; without one (which should not have occurred for a well prompted
    # and trained model) only the EOT token itself is removed.
    content_prefix_end = eot_index if bot_index is None else bot_index
    return input_ids[:content_prefix_end] + input_ids[eot_index + 1 :]
def extract_reasoning(
    self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
) -> tuple[str | None, str | None]:
    """
    Split the model output into a `(reasoning, content)` pair.

    An end-of-think (EOT) token only closes a reasoning section when it
    follows a start-of-think (BOT) token. Without a BOT token the whole
    output is returned as content, with the first EOT token, if any,
    removed. `request` is not consulted.
    """
    if not model_output:
        return (None, "")

    before_bot, bot_separator, after_bot = model_output.partition(self.start_token)

    if not bot_separator:
        # No BOT token: nothing counts as reasoning. A stray EOT token is
        # dropped from the content; if there is none, the partition leaves
        # the text unchanged.
        before_eot, _, after_eot = before_bot.partition(self.end_token)
        return None, before_eot + after_eot

    # BOT token present: reasoning runs up to the first EOT token that
    # follows it, or to the end of the output when there is none (in which
    # case `after_eot` is empty).
    reasoning, _, after_eot = after_bot.partition(self.end_token)

    # If model is well prompted and trained, before_bot should be "".
    content = before_bot + after_eot
    return reasoning, (content if content else None)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .deepseekv32 import DeepseekV32Tokenizer
from .hf import HfTokenizer
from .mistral import MistralTokenizer
from .protocol import TokenizerLike from .protocol import TokenizerLike
from .registry import ( from .registry import (
TokenizerRegistry, TokenizerRegistry,
...@@ -15,12 +12,9 @@ from .registry import ( ...@@ -15,12 +12,9 @@ from .registry import (
__all__ = [ __all__ = [
"TokenizerLike", "TokenizerLike",
"HfTokenizer",
"MistralTokenizer",
"TokenizerRegistry", "TokenizerRegistry",
"cached_get_tokenizer", "cached_get_tokenizer",
"get_tokenizer", "get_tokenizer",
"cached_tokenizer_from_config", "cached_tokenizer_from_config",
"init_tokenizer_from_config", "init_tokenizer_from_config",
"DeepseekV32Tokenizer",
] ]
...@@ -2,22 +2,18 @@ ...@@ -2,22 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path from pathlib import Path
from typing import Any
from transformers import BatchEncoding from transformers import BatchEncoding
from .deepseek_v32_encoding import encode_messages from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from .hf import HfTokenizer, TokenizerLike
from .registry import TokenizerRegistry
from .deepseek_v32_encoding import encode_messages
from .hf import CachedHfTokenizer
from .protocol import TokenizerLike
@TokenizerRegistry.register("deepseek_v32")
class DeepseekV32Tokenizer(HfTokenizer):
def __init__(self, tokenizer: TokenizerLike):
self.tokenizer = tokenizer
self.name_or_path = (
tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else ""
)
class DeepseekV32Tokenizer(CachedHfTokenizer):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
...@@ -38,20 +34,47 @@ class DeepseekV32Tokenizer(HfTokenizer): ...@@ -38,20 +34,47 @@ class DeepseekV32Tokenizer(HfTokenizer):
) )
return DeepseekV32Tokenizer(tokenizer) return DeepseekV32Tokenizer(tokenizer)
def apply_chat_template(self, messages, tools=None, **kwargs): def __init__(self, tokenizer: TokenizerLike) -> None:
super().__init__()
self.tokenizer = tokenizer
self.name_or_path = getattr(tokenizer, "name_or_path", "")
self._added_vocab = self.tokenizer.get_added_vocab()
self._added_vocab_size = len(self._added_vocab)
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> str | list[int]:
thinking = kwargs.get("thinking", False) thinking = kwargs.get("thinking", False)
thinking_mode = "thinking" thinking_mode = "thinking"
if not thinking: if not thinking:
thinking_mode = "chat" thinking_mode = "chat"
conversation = kwargs.get("conversation", messages) conversation = kwargs.get("conversation", messages)
messages = conversation.copy() messages = conversation.copy()
drop_thinking = True
if tools is not None and len(tools) > 0: if tools is not None and len(tools) > 0:
messages.insert(0, {"role": "system"}) messages.insert(0, {"role": "system"})
messages[0]["tools"] = tools messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key]
drop_thinking = False
# Historical reasoning content is dropped when a new user message is introduced
drop_thinking = messages[-1]["role"] == "user"
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
prompt_str = encode_messages(messages, **encode_config) # type: ignore prompt_str = encode_messages(messages, **encode_config) # type: ignore
if kwargs.get("tokenize", True):
tokenizer_kwargs = {
k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
}
return self.encode(
prompt_str,
add_special_tokens=False,
**tokenizer_kwargs,
)
return prompt_str return prompt_str
def num_special_tokens_to_add(self) -> int: def num_special_tokens_to_add(self) -> int:
...@@ -98,7 +121,7 @@ class DeepseekV32Tokenizer(HfTokenizer): ...@@ -98,7 +121,7 @@ class DeepseekV32Tokenizer(HfTokenizer):
def __len__(self) -> int: def __len__(self) -> int:
# </think> is an added token in DeepseekV32 tokenizer # </think> is an added token in DeepseekV32 tokenizer
return self.vocab_size + len(self.get_added_vocab()) return self.vocab_size + self._added_vocab_size
def __call__( def __call__(
self, self,
...@@ -120,7 +143,7 @@ class DeepseekV32Tokenizer(HfTokenizer): ...@@ -120,7 +143,7 @@ class DeepseekV32Tokenizer(HfTokenizer):
return self.tokenizer.get_vocab() return self.tokenizer.get_vocab()
def get_added_vocab(self) -> dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
return self.tokenizer.get_added_vocab() return self._added_vocab.copy()
def encode( def encode(
self, self,
......
...@@ -3,22 +3,18 @@ ...@@ -3,22 +3,18 @@
import contextlib import contextlib
import copy import copy
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TypeAlias
from transformers import AutoTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from .protocol import TokenizerLike from .protocol import TokenizerLike
from .registry import TokenizerRegistry
if TYPE_CHECKING: HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
def get_cached_tokenizer( def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
) -> TokenizerLike:
""" """
By default, transformers will recompute multiple tokenizer properties By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown. each time they are called, leading to a significant slowdown.
...@@ -65,11 +61,10 @@ def get_cached_tokenizer( ...@@ -65,11 +61,10 @@ def get_cached_tokenizer(
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
cached_tokenizer.__class__ = CachedTokenizer cached_tokenizer.__class__ = CachedTokenizer
return cached_tokenizer # type: ignore return cached_tokenizer
@TokenizerRegistry.register("hf") class CachedHfTokenizer(TokenizerLike):
class HfTokenizer(TokenizerLike):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
...@@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike): ...@@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike):
revision: str | None = None, revision: str | None = None,
download_dir: str | None = None, download_dir: str | None = None,
**kwargs, **kwargs,
) -> "TokenizerLike": ) -> HfTokenizer:
try: try:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
path_or_repo_id, path_or_repo_id,
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from .protocol import TokenizerLike from .protocol import TokenizerLike
from .registry import TokenizerRegistry
if TYPE_CHECKING: if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import ( from mistral_common.protocol.instruct.request import (
...@@ -15,9 +16,6 @@ if TYPE_CHECKING: ...@@ -15,9 +16,6 @@ if TYPE_CHECKING:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers import BatchEncoding from transformers import BatchEncoding
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
try: try:
# Transformers v5 # Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend from transformers.tokenization_mistral_common import MistralCommonBackend
...@@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: ...@@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
return tokenizer.unk_id return tokenizer.unk_id
@TokenizerRegistry.register("mistral")
class MistralTokenizer(TokenizerLike): class MistralTokenizer(TokenizerLike):
@classmethod @classmethod
def from_pretrained( def from_pretrained(
......
...@@ -97,7 +97,7 @@ class TokenizerLike(Protocol): ...@@ -97,7 +97,7 @@ class TokenizerLike(Protocol):
messages: list["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
**kwargs, **kwargs,
) -> list[int]: ) -> str | list[int]:
raise NotImplementedError raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util import importlib.util
from collections.abc import Callable from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, overload from typing import TYPE_CHECKING
import huggingface_hub import huggingface_hub
from typing_extensions import assert_never from typing_extensions import TypeVar, assert_never, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike from .protocol import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config.model import ModelConfig, RunnerType
logger = init_logger(__name__) logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[TokenizerLike])
_VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
}
class TokenizerRegistry:
# Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class)
REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {}
# In-tree tokenizers @dataclass
@staticmethod class _TokenizerRegistry:
@overload # Tokenizer mode -> (tokenizer module, tokenizer class)
def register(tokenizer_mode: str) -> Callable[[_T], _T]: ... tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)
# OOT tokenizers def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
@staticmethod if tokenizer_mode in self.tokenizers:
@overload
def register(tokenizer_mode: str, module: str, class_name: str) -> None: ...
@staticmethod
def register(
tokenizer_mode: str,
module: str | None = None,
class_name: str | None = None,
) -> Callable[[_T], _T] | None:
# In-tree tokenizers
if module is None or class_name is None:
def wrapper(tokenizer_cls: _T) -> _T:
assert tokenizer_mode not in TokenizerRegistry.REGISTRY
TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls
return tokenizer_cls
return wrapper
# OOT tokenizers
if tokenizer_mode in TokenizerRegistry.REGISTRY:
logger.warning( logger.warning(
"%s.%s is already registered for tokenizer_mode=%r. " "%s.%s is already registered for tokenizer_mode=%r. "
"It is overwritten by the new one.", "It is overwritten by the new one.",
...@@ -72,36 +51,42 @@ class TokenizerRegistry: ...@@ -72,36 +51,42 @@ class TokenizerRegistry:
tokenizer_mode, tokenizer_mode,
) )
TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name) self.tokenizers[tokenizer_mode] = (module, class_name)
return None return None
@staticmethod def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike": if tokenizer_mode not in self.tokenizers:
if tokenizer_mode not in TokenizerRegistry.REGISTRY:
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.") raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
item = TokenizerRegistry.REGISTRY[tokenizer_mode] module, class_name = self.tokenizers[tokenizer_mode]
if isinstance(item, type):
return item.from_pretrained(*args, **kwargs)
module, class_name = item
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}") logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
class_ = resolve_obj_by_qualname(f"{module}.{class_name}") return resolve_obj_by_qualname(f"{module}.{class_name}")
return class_.from_pretrained(*args, **kwargs)
def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
return tokenizer_cls.from_pretrained(*args, **kwargs)
def get_tokenizer(
TokenizerRegistry = _TokenizerRegistry(
{
mode: (f"vllm.tokenizers.{mod_relname}", cls_name)
for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items()
}
)
def resolve_tokenizer_args(
tokenizer_name: str | Path, tokenizer_name: str | Path,
*args, *args,
runner_type: "RunnerType" = "generate",
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs, **kwargs,
) -> TokenizerLike: ):
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" revision: str | None = kwargs.get("revision")
download_dir: str | None = kwargs.get("download_dir")
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
...@@ -125,16 +110,6 @@ def get_tokenizer( ...@@ -125,16 +110,6 @@ def get_tokenizer(
) )
tokenizer_name = tokenizer_path tokenizer_name = tokenizer_path
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
tokenizer_mode = "hf"
kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
if is_gguf(tokenizer_name): if is_gguf(tokenizer_name):
if check_gguf_file(tokenizer_name): if check_gguf_file(tokenizer_name):
...@@ -150,6 +125,21 @@ def get_tokenizer( ...@@ -150,6 +125,21 @@ def get_tokenizer(
) )
kwargs["gguf_file"] = gguf_file kwargs["gguf_file"] = gguf_file
if "truncation_side" not in kwargs:
if runner_type == "generate" or runner_type == "draft":
kwargs["truncation_side"] = "left"
elif runner_type == "pooling":
kwargs["truncation_side"] = "right"
else:
assert_never(runner_type)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
tokenizer_mode = "hf"
kwargs["use_fast"] = False
# Try to use official Mistral tokenizer if possible # Try to use official Mistral tokenizer if possible
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
allow_patterns = ["tekken.json", "tokenizer.model.v*"] allow_patterns = ["tekken.json", "tokenizer.model.v*"]
...@@ -165,49 +155,70 @@ def get_tokenizer( ...@@ -165,49 +155,70 @@ def get_tokenizer(
if tokenizer_mode == "auto": if tokenizer_mode == "auto":
tokenizer_mode = "hf" tokenizer_mode = "hf"
tokenizer_args = (tokenizer_name, *args) return tokenizer_mode, tokenizer_name, args, kwargs
tokenizer_kwargs = dict(
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)
def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
    """Resolve tokenizer arguments for a model config via the cached resolver.

    `config.tokenizer` is passed positionally. The runner type, tokenizer
    mode, tokenizer revision (as `revision`) and `trust_remote_code` flag are
    taken from `config`, and any extra keyword arguments are forwarded
    unchanged.
    """
    # Keyword order is kept stable on purpose: it is part of the call made
    # to the cached resolver.
    config_kwargs = {
        "runner_type": config.runner_type,
        "tokenizer_mode": config.tokenizer_mode,
        "revision": config.tokenizer_revision,
        "trust_remote_code": config.trust_remote_code,
    }
    return cached_resolve_tokenizer_args(
        config.tokenizer,
        **config_kwargs,
        **kwargs,
    )
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
def get_tokenizer(
tokenizer_name: str | Path,
*args,
tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment]
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> _T:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
download_dir=download_dir, download_dir=download_dir,
**kwargs, **kwargs,
) )
if tokenizer_mode == "custom": if tokenizer_cls == TokenizerLike:
logger.warning_once( tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
"TokenizerRegistry now uses `tokenizer_mode` as the registry key " else:
"instead of `tokenizer_name`. " tokenizer_cls_ = tokenizer_cls
"Please update the definition of `.from_pretrained` in "
"your custom tokenizer to accept `args=%s`, `kwargs=%s`. "
"Then, you can pass `tokenizer_mode=%r` instead of "
"`tokenizer_mode='custom'` when initializing vLLM.",
tokenizer_args,
str(tokenizer_kwargs),
tokenizer_name,
)
tokenizer_mode = str(tokenizer_name)
tokenizer = TokenizerRegistry.get_tokenizer( tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs)
tokenizer_mode,
*tokenizer_args,
**tokenizer_kwargs,
)
if not tokenizer.is_fast: if not tokenizer.is_fast:
logger.warning( logger.warning(
"Using a slow tokenizer. This might cause a significant " "Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead." "slowdown. Consider using a fast tokenizer instead."
) )
return tokenizer return tokenizer # type: ignore
cached_get_tokenizer = lru_cache(get_tokenizer) cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
if model_config.skip_tokenizer_init:
return None
return cached_get_tokenizer( return cached_get_tokenizer(
model_config.tokenizer, model_config.tokenizer,
runner_type=model_config.runner_type,
tokenizer_mode=model_config.tokenizer_mode, tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision, revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
...@@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): ...@@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
) )
@deprecated(
"Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14."
)
def init_tokenizer_from_config(model_config: "ModelConfig"): def init_tokenizer_from_config(model_config: "ModelConfig"):
runner_type = model_config.runner_type return cached_tokenizer_from_config(model_config)
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment