Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
# SPDX-License-Identifier: Apache-2.0
import os
import torch
# Process-wide environment defaults. They must hold for every process
# created by vllm and for every process that interacts with vllm workers,
# so they are applied whenever `import vllm` is called.
os.environ.update({
    # see https://github.com/NVIDIA/nccl/issues/1234
    'NCCL_CUMEM_ENABLE': '0',
    # see https://github.com/vllm-project/vllm/pull/15951
    # it avoids unintentional cuda initialization from
    # torch.cuda.is_available()
    'PYTORCH_NVML_BASED_CUDA_CHECK': '1',
    # see https://github.com/vllm-project/vllm/issues/10480
    'TORCHINDUCTOR_COMPILE_THREADS': '1',
})
# see https://github.com/vllm-project/vllm/issues/10619
# torch._inductor.config.compile_threads = 1
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import hashlib import hashlib
import os import os
import sys
import tempfile import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
...@@ -29,6 +30,7 @@ if TYPE_CHECKING: ...@@ -29,6 +30,7 @@ if TYPE_CHECKING:
S3_ACCESS_KEY_ID: Optional[str] = None S3_ACCESS_KEY_ID: Optional[str] = None
S3_SECRET_ACCESS_KEY: Optional[str] = None S3_SECRET_ACCESS_KEY: Optional[str] = None
S3_ENDPOINT_URL: Optional[str] = None S3_ENDPOINT_URL: Optional[str] = None
VLLM_MODEL_REDIRECT_PATH: Optional[str] = None
VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm") VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm")
VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
...@@ -53,7 +55,7 @@ if TYPE_CHECKING: ...@@ -53,7 +55,7 @@ if TYPE_CHECKING:
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
...@@ -81,9 +83,13 @@ if TYPE_CHECKING: ...@@ -81,9 +83,13 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
...@@ -99,12 +105,15 @@ if TYPE_CHECKING: ...@@ -99,12 +105,15 @@ if TYPE_CHECKING:
VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0 VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1 VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -302,6 +311,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -302,6 +311,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_API_KEY": "VLLM_API_KEY":
lambda: os.environ.get("VLLM_API_KEY", None), lambda: os.environ.get("VLLM_API_KEY", None),
# Whether to log responses from API Server for debugging
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
lower() == "true",
# S3 access information, used for tensorizer to load model from S3 # S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID": "S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY_ID", None), lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
...@@ -404,15 +418,21 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -404,15 +418,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
# (previously known as ADAG) API which optimizes the # (previously known as ADAG) API which optimizes the
# control plane overhead. # control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Note that this variable is set to 1 in V1 by default
# when ray distributed executor is used.
"VLLM_USE_RAY_COMPILED_DAG": "VLLM_USE_RAY_COMPILED_DAG":
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
# If the env var is set, it uses NCCL for communication in # If the env var is set, Ray Compiled Graph uses the specified
# Ray's Compiled Graph. This flag is ignored if # channel type to communicate between workers belonging to
# VLLM_USE_RAY_COMPILED_DAG is not set. # different pipeline-parallel stages.
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": # Available options:
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) # - "auto": use the default channel type
), # - "nccl": use NCCL for communication
# - "shm": use shared memory and gRPC for communication
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"),
# If the env var is set, it enables GPU communication overlap # If the env var is set, it enables GPU communication overlap
# (experimental feature) in Ray's Compiled Graph. This flag is ignored if # (experimental feature) in Ray's Compiled Graph. This flag is ignored if
...@@ -554,6 +574,26 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -554,6 +574,26 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")), ("true", "1")),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
"VLLM_ROCM_USE_AITER_LINEAR":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
("true", "1")),
# Whether to use aiter moe ops.
# Enabled by default.
"VLLM_ROCM_USE_AITER_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),
# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled. # use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM": "VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
...@@ -567,6 +607,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -567,6 +607,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_MOE_PADDING": "VLLM_ROCM_MOE_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
# custom paged attention kernel for MI3* cards
"VLLM_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache # Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT": "Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
...@@ -643,6 +688,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -643,6 +688,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_RANK": "VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")), lambda: int(os.getenv("VLLM_DP_RANK", "0")),
# Local rank of the process in the data parallel setting.
# Defaults to VLLM_DP_RANK when not set.
"VLLM_DP_RANK_LOCAL":
lambda: int(
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
# World size of the data parallel setting # World size of the data parallel setting
"VLLM_DP_SIZE": "VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")), lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
...@@ -659,6 +710,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -659,6 +710,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CI_USE_S3": "VLLM_CI_USE_S3":
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder.
"VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
# Whether to use atomicAdd reduce in gptq/awq marlin kernel. # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
"VLLM_MARLIN_USE_ATOMIC_ADD": "VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
...@@ -673,6 +728,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -673,6 +728,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION": "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"])) lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None, if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
# Gap between padding buckets for the forward pass. For example, if the
# gap is 8, we will run the forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -51,6 +51,7 @@ class ExecutorBase(ABC): ...@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod @abstractmethod
def _init_executor(self) -> None: def _init_executor(self) -> None:
...@@ -204,20 +205,34 @@ class ExecutorBase(ABC): ...@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter() time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level)) self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter() time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.", logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep) time_after_sleep - time_before_sleep)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
if not self.is_sleeping: if not self.is_sleeping:
logger.warning("Executor is not sleeping.") logger.warning("Executor is not sleeping.")
return return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning("Tag %s is not in sleeping tags %s", tag,
self.sleeping_tags)
return
time_before_wakeup = time.perf_counter() time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up") self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter() time_after_wakeup = time.perf_counter()
self.is_sleeping = False logger.info("It took %.6f seconds to wake up tags %s.",
logger.info("It took %.6f seconds to wake up.", time_after_wakeup - time_before_wakeup,
time_after_wakeup - time_before_wakeup) tags if tags is not None else self.sleeping_tags)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def save_sharded_state( def save_sharded_state(
self, self,
......
...@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# For TPU, avoid compiling NVIDIA's NCCL # For TPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu(): if current_platform.is_tpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0" os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
...@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Run `pip install ray[cgraph]` to install it.") "Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy") cupy_spec = importlib.util.find_spec("cupy")
if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError( raise ValueError(
"cupy is not installed but required since " "cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. " "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.") "Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool): def _compiled_ray_dag(self, enable_asyncio: bool):
...@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._check_ray_cgraph_installation() self._check_ray_cgraph_installation()
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to # (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution, # control the timeout of getting result from a compiled graph execution,
...@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
] ]
last_pp_rank = len(self.pp_tp_workers) - 1 last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank: if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
# Specify how intermediate tensors should be passed # Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last # between pp stages, no need to specify for the last
# pp stage. # pp stage or when using shared memory (the default).
transport = "nccl" \ transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
outputs = [ outputs = [
output.with_tensor_transport(transport=transport) output.with_tensor_transport(transport=transport)
for output in outputs for output in outputs
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt( ...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: inputs: ProcessorInputs,
return "encoder" in inputs and "decoder" in inputs ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
if "encoder" in inputs and "decoder" in inputs:
# NOTE: This passes pyright but not mypy
return (
inputs["encoder"], # type: ignore[typeddict-item]
inputs["decoder"], # type: ignore[typeddict-item]
)
return None, inputs
...@@ -261,13 +261,13 @@ class InputPreprocessor: ...@@ -261,13 +261,13 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal # initialized without a tokenizer while using also multi-modal
# input. # input.
if not self.tokenizer: if not self.tokenizer:
tokenizer = None tokenizer = object() # Dummy
else: else:
tokenizer_group = self.get_tokenizer_group() tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(self.model_config,
self.model_config, tokenizer) tokenizer=tokenizer)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -288,14 +288,14 @@ class InputPreprocessor: ...@@ -288,14 +288,14 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal # initialized without a tokenizer while using also multi-modal
# input. # input.
if not self.tokenizer: if not self.tokenizer:
tokenizer = None tokenizer = object() # Dummy
else: else:
tokenizer_group = self.get_tokenizer_group() tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async( tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request) lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(self.model_config,
self.model_config, tokenizer) tokenizer=tokenizer)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -528,6 +528,7 @@ class InputPreprocessor: ...@@ -528,6 +528,7 @@ class InputPreprocessor:
prompt_token_ids=decoder_inputs_to_override[ prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"], "prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
else: else:
...@@ -536,6 +537,7 @@ class InputPreprocessor: ...@@ -536,6 +537,7 @@ class InputPreprocessor:
prompt=inputs["prompt"], prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"], prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
elif inputs["type"] == "token": elif inputs["type"] == "token":
......
...@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never ...@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import AnyTokenizer
cached_tokenizer_from_config)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs from .parse import split_enc_dec_inputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -329,17 +328,27 @@ class InputRegistry: ...@@ -329,17 +328,27 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData
if mm_registry.has_processor(model_config): if mm_registry.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config)
processor = mm_registry.create_processor(model_config, processor = mm_registry.create_processor(model_config,
tokenizer,
disable_cache=True) disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data_factory = (profiler.get_encoder_dummy_data
if is_encoder_data else dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
profiler.get_decoder_dummy_data) if is_encoder_data else
dummy_data = dummy_data_factory(seq_len) profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else: else:
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: if is_encoder_data:
...@@ -462,13 +471,11 @@ class InputRegistry: ...@@ -462,13 +471,11 @@ class InputRegistry:
**mm_processor_kwargs, **mm_processor_kwargs,
) )
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._ensure_mm_kwargs(processed_inputs["encoder"], if encoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"], if decoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
return processed_inputs return processed_inputs
......
...@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel): ...@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}." f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct") f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, map_location=device) tensors = torch.load(lora_bin_file_path,
map_location=device,
weights_only=True)
else: else:
raise ValueError(f"{lora_dir} doesn't contain tensors") raise ValueError(f"{lora_dir} doesn't contain tensors")
......
...@@ -130,7 +130,7 @@ def do_expand_kernel( ...@@ -130,7 +130,7 @@ def do_expand_kernel(
# Identify A and B block pointers # Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K) offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, ) offset_k[None, :] * input_d2_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride + offset_k[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride) rbn[None, :] * cur_lora_d1_stride)
......
...@@ -136,6 +136,7 @@ def _lora_expand( ...@@ -136,6 +136,7 @@ def _lora_expand(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
...@@ -157,11 +158,19 @@ def _lora_expand( ...@@ -157,11 +158,19 @@ def _lora_expand(
identifies the region in token_indices_sorted_by_lora_ids that identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process. LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process. lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1 that is True
when none of the requests require LoRA.
offset_start (int, optional): Offset start for output_tensor. offset_start (int, optional): Offset start for output_tensor.
Defaults to 0. Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False. output tensor. Defaults to False.
""" """
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights: for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16] assert weight.dtype in [torch.float16, torch.bfloat16]
...@@ -170,6 +179,8 @@ def _lora_expand( ...@@ -170,6 +179,8 @@ def _lora_expand(
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
# metadata sanity check. # metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0) 0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0)
...@@ -181,7 +192,6 @@ def _lora_expand( ...@@ -181,7 +192,6 @@ def _lora_expand(
inputs.device) inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank K = lora_b_weights[0].shape[-1] # K= rank
M = inputs.size(1)
ADD_INPUTS = add_inputs ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0) MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False CAST_TYPE = False
...@@ -263,6 +273,7 @@ def _lora_expand_fake( ...@@ -263,6 +273,7 @@ def _lora_expand_fake(
num_tokens_per_lora: torch.Tensor, num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
......
...@@ -17,6 +17,17 @@ class LoRAKernelMeta: ...@@ -17,6 +17,17 @@ class LoRAKernelMeta:
num_tokens_per_lora: torch.Tensor num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor lora_token_start_loc: torch.Tensor
# The V1 architecture uses the traced torch.compile graphs to execute
# a forward pass. Things to note about this process,
# 1. The tracing infers all python scalar datatype objects into a constant
# value.
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
# is an experimental feature in pytorch)
# 3. The internals of torch.ops functions are not traced.
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
@staticmethod @staticmethod
def make(max_loras: int, max_num_tokens: int, def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "LoRAKernelMeta": device: Union[torch.device, str]) -> "LoRAKernelMeta":
...@@ -47,17 +58,24 @@ class LoRAKernelMeta: ...@@ -47,17 +58,24 @@ class LoRAKernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2, lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
no_lora_flag_cpu = torch.tensor([False],
dtype=torch.bool,
device='cpu')
return LoRAKernelMeta( return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping, token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids, active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora, num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc) lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu)
def _reset(self): def _reset(self):
self.active_lora_ids.fill_(-1) self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0) self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0) self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
""" """
...@@ -70,6 +88,14 @@ class LoRAKernelMeta: ...@@ -70,6 +88,14 @@ class LoRAKernelMeta:
self._reset() self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0) num_tokens = token_lora_mapping.size(0)
# copy token lora mapping # copy token lora mapping
...@@ -100,7 +126,7 @@ class LoRAKernelMeta: ...@@ -100,7 +126,7 @@ class LoRAKernelMeta:
def meta_args( def meta_args(
self, token_nums: int self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]: torch.Tensor, torch.Tensor]:
""" """
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the forward pass execution of the kernel. The function returns all the
...@@ -111,7 +137,11 @@ class LoRAKernelMeta: ...@@ -111,7 +137,11 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward token_nums (int): Number of input tokens in the current forward
pass. pass.
""" """
return (self.token_lora_mapping[:token_nums], return (
self.token_indices_sorted_by_lora_ids[:token_nums], self.token_lora_mapping[:token_nums],
self.num_tokens_per_lora, self.lora_token_start_loc, self.token_indices_sorted_by_lora_ids[:token_nums],
self.active_lora_ids) self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
)
...@@ -106,6 +106,7 @@ def _lora_shrink( ...@@ -106,6 +106,7 @@ def _lora_shrink(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
...@@ -126,8 +127,16 @@ def _lora_shrink( ...@@ -126,8 +127,16 @@ def _lora_shrink(
identifies the region in token_indices_sorted_by_lora_ids that identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process. LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process. lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1 that is True
when none of the requests require LoRA (the kernel then returns early).
scaling (float): Scaling factor. scaling (float): Scaling factor.
""" """
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16] assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights: for weight in lora_a_weights:
...@@ -138,6 +147,8 @@ def _lora_shrink( ...@@ -138,6 +147,8 @@ def _lora_shrink(
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
# metadata sanity check # metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0) 0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0)
...@@ -146,7 +157,6 @@ def _lora_shrink( ...@@ -146,7 +157,6 @@ def _lora_shrink(
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
M = inputs.size(0)
NUM_SLICES = len(lora_a_weights) NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0) MAX_LORAS = lora_ids.size(0)
...@@ -218,6 +228,7 @@ def _lora_shrink_fake( ...@@ -218,6 +228,7 @@ def _lora_shrink_fake(
num_tokens_per_lora: torch.Tensor, num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float, scaling: float,
) -> None: ) -> None:
return return
......
...@@ -5,10 +5,10 @@ from __future__ import annotations ...@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import ( from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark, convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -79,12 +79,6 @@ def maybe_backend_fallback( ...@@ -79,12 +79,6 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the " "xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines") "grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully, # If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback. # we should still allow users to use guided decoding with a fallback.
elif not xgr_installed: elif not xgr_installed:
...@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor( ...@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig, model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend) reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
...@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor( ...@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_ reasoner = None
reasoner = get_reasoner(tokenizer, reasoning_backend) if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines': if guided_params.backend_name == 'outlines':
......
...@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor( ...@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
""" """
grm = "" grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
if guided_params.json: if guided_params.json:
grm = llguidance.LLMatcher.grammar_from_json_schema( grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json, guided_params.json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object: elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema( grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}', '{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex: elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex) grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice: elif guided_params.choice:
......
...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16 ...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -141,7 +141,7 @@ def _get_logits_processor( ...@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,9 +49,9 @@ else: ...@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor: class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int, self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
...@@ -69,7 +69,7 @@ class BaseLogitsProcessor: ...@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids # Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the # We need this because our implementation relies on the
# hash of the input_ids to store the FSM state. # hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids) input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self, self,
regex_string: str, regex_string: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
): ):
"""Compile the FSM that drives the regex-structured generation. """Compile the FSM that drives the regex-structured generation.
...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel], def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation. """Compile the FSM that drives the context free grammar generation.
Parameters Parameters
......
...@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer, ...@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
return None return None
elif reasoning_backend == "deepseek_r1": elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer) return DeepSeekReasoner.from_tokenizer(tokenizer)
elif reasoning_backend == "granite":
logger.warning(
"Granite reasoner not yet implemented for structured outputs")
return None
else: else:
# Raise a warning for unknown reasoning backend and return None # Raise a warning for unknown reasoning backend and return None
# We cannot raise an error here because some reasoning models # We cannot raise an error here because some reasoning models
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.

    The reasoning section is delimited by ``start_token`` and ``end_token``;
    everything after the first ``end_token_id`` is treated as content.
    """
    # Ids of the first token produced when encoding the marker strings below.
    start_token_id: int
    end_token_id: int
    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Build a reasoner by resolving the marker strings to token ids.

        Each marker is encoded without special tokens and the first
        resulting id is used.
        """

        def first_token_id(marker: str) -> int:
            return tokenizer.encode(marker, add_special_tokens=False)[0]

        # Use the class-level defaults so the marker strings are defined in
        # exactly one place.
        return cls(start_token_id=first_token_id(cls.start_token),
                   end_token_id=first_token_id(cls.end_token))

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the end-of-reasoning token id is in ``input_ids``."""
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens

        Returns an empty list when the end token is absent or is the last
        element of ``input_ids``.
        """
        # A single lookup replaces the separate membership test and the
        # repeated ``index`` calls.
        try:
            end_index = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        if end_index + 1 == len(input_ids):
            # Nothing follows the end token.
            return []
        return input_ids[end_index + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
    """Abstract interface for reasoners used by guided decoding.

    Subclasses must implement all three methods below before they can be
    instantiated.
    """

    @classmethod
    @abstractmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Construct a reasoner from ``tokenizer``.

        Declared as a classmethod because it receives ``cls`` and serves as
        an alternate constructor.
        """

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether ``input_ids`` indicate that reasoning has ended."""

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        """Return the content token ids extracted from ``input_ids``."""
...@@ -27,7 +27,7 @@ if TYPE_CHECKING: ...@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( ...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig, model_config: ModelConfig,
reasoner: Reasoner | None, reasoner: ReasoningParser | None,
max_threads: int = 8): max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params, config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config, model_config=model_config,
...@@ -280,7 +280,7 @@ class GrammarConfig: ...@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor: class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol""" """Wrapper class to support pickle protocol"""
config: GrammarConfig config: GrammarConfig
reasoner: Reasoner | None = None reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
...@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor: ...@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor:
elif self.config.grammar_str is not None: elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str) self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object: elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar() any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
else: else:
raise ValueError( raise ValueError(
"Invalid configuration for xgrammar logits processor") "Invalid configuration for xgrammar logits processor")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment