Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
# SPDX-License-Identifier: Apache-2.0
import os
import torch
# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available()
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
# torch._inductor.config.compile_threads = 1
......@@ -2,6 +2,7 @@
import hashlib
import os
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional
......@@ -29,6 +30,7 @@ if TYPE_CHECKING:
S3_ACCESS_KEY_ID: Optional[str] = None
S3_SECRET_ACCESS_KEY: Optional[str] = None
S3_ENDPOINT_URL: Optional[str] = None
VLLM_MODEL_REDIRECT_PATH: Optional[str] = None
VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm")
VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
......@@ -53,7 +55,7 @@ if TYPE_CHECKING:
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_SPMD_WORKER: bool = False
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
......@@ -81,9 +83,13 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
......@@ -99,12 +105,15 @@ if TYPE_CHECKING:
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
def get_default_cache_root():
......@@ -302,6 +311,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_API_KEY":
lambda: os.environ.get("VLLM_API_KEY", None),
# Whether to log responses from API Server for debugging
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
lower() == "true",
# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
......@@ -404,15 +418,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
# (previously known as ADAG) API which optimizes the
# control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Note that this variable is set to 1 in V1 by default
# when ray distributed executor is used.
"VLLM_USE_RAY_COMPILED_DAG":
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
# If the env var is set, it uses NCCL for communication in
# Ray's Compiled Graph. This flag is ignored if
# VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
),
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
# Available options:
# - "auto": use the default channel type
# - "nccl": use NCCL for communication
# - "shm": use shared memory and gRPC for communication
# This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"),
# If the env var is set, it enables GPU communication overlap
# (experimental feature) in Ray's Compiled Graph. This flag is ignored if
......@@ -554,6 +574,26 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
"VLLM_ROCM_USE_AITER_LINEAR":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
("true", "1")),
# Whether to use aiter moe ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),
# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
......@@ -567,6 +607,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_MOE_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
# custom paged attention kernel for MI3* cards
"VLLM_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
......@@ -643,6 +688,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
# Local rank of the process in the data parallel setting.
# Defaults to VLLM_DP_RANK when not set.
"VLLM_DP_RANK_LOCAL":
lambda: int(
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
# World size of the data parallel setting
"VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
......@@ -659,6 +710,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CI_USE_S3":
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
# Use model_redirect to redirect the model name to a local folder.
"VLLM_MODEL_REDIRECT_PATH":
lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None),
# Whether to use atomicAdd reduce in gptq/awq marlin kernel.
"VLLM_MARLIN_USE_ATOMIC_ADD":
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
......@@ -673,6 +728,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
# Gap between padding buckets for the forward pass. So if the gap is
# 8, we will run the forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
}
# end-env-vars-definition
......
......@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod
def _init_executor(self) -> None:
......@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep)
def wake_up(self, tags: Optional[list[str]] = None):
    """Wake the workers up, either completely or only for selected tags.

    Args:
        tags: Names of the sleeping components to wake up (``sleep``
            records ``"weights"`` and ``"kv_cache"``). When ``None`` or
            empty, the executor marks every sleeping tag as awake.

    The call is a no-op (with a warning) when the executor is not
    sleeping, or when any requested tag is not currently sleeping.
    """
    if not self.is_sleeping:
        logger.warning("Executor is not sleeping.")
        return
    if tags:
        # Reject the whole request if any tag is not currently asleep.
        for tag in tags:
            if tag not in self.sleeping_tags:
                logger.warning("Tag %s is not in sleeping tags %s", tag,
                               self.sleeping_tags)
                return
    time_before_wakeup = time.perf_counter()
    self.collective_rpc("wake_up", kwargs=dict(tags=tags))
    time_after_wakeup = time.perf_counter()
    # Log before mutating sleeping_tags so the message shows what was woken.
    logger.info("It took %.6f seconds to wake up tags %s.",
                time_after_wakeup - time_before_wakeup,
                tags if tags is not None else self.sleeping_tags)
    if tags:
        # difference_update tolerates a tag that is listed twice, whereas
        # set.remove would raise KeyError on the second occurrence.
        self.sleeping_tags.difference_update(tags)
    else:
        self.sleeping_tags.clear()
    # Only report "awake" once nothing is left sleeping.
    if not self.sleeping_tags:
        self.is_sleeping = False
def save_sharded_state(
self,
......
......@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# For TPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
......@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy")
if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool):
......@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._check_ray_cgraph_installation()
from ray.dag import InputNode, MultiOutputNode
logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
......@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "nccl" \
if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
# pp stage or when using shared memory (the default).
transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
outputs = [
output.with_tensor_transport(transport=transport)
for output in outputs
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload
from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
......@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs(
        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
    """Tell whether ``inputs`` carries both an encoder and a decoder entry."""
    required_keys = ("encoder", "decoder")
    return all(key in inputs for key in required_keys)
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
    """Separate processed inputs into an (encoder, decoder) pair.

    When ``inputs`` has both an ``"encoder"`` and a ``"decoder"`` key, the
    two entries are returned. Otherwise ``inputs`` itself is returned as
    the decoder part and the encoder part is ``None``.
    """
    has_both_halves = "encoder" in inputs and "decoder" in inputs
    if not has_both_halves:
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_part = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_part = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_part, decoder_part
......@@ -261,13 +261,13 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if not self.tokenizer:
tokenizer = None
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -288,14 +288,14 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal
# input.
if not self.tokenizer:
tokenizer = None
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -528,6 +528,7 @@ class InputPreprocessor:
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
......@@ -536,6 +537,7 @@ class InputPreprocessor:
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
......
......@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
from .parse import split_enc_dec_inputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -329,17 +328,27 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData
if mm_registry.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config)
processor = mm_registry.create_processor(model_config,
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor)
dummy_data_factory = (profiler.get_encoder_dummy_data
if is_encoder_data else
profiler.get_decoder_dummy_data)
dummy_data = dummy_data_factory(seq_len)
dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
if is_encoder_data else
profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
......@@ -462,13 +471,11 @@ class InputRegistry:
**mm_processor_kwargs,
)
if is_encoder_decoder_inputs(processed_inputs):
self._ensure_mm_kwargs(processed_inputs["encoder"],
mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"],
mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
if encoder_inputs is not None:
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
if decoder_inputs is not None:
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
return processed_inputs
......
......@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, map_location=device)
tensors = torch.load(lora_bin_file_path,
map_location=device,
weights_only=True)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
......
......@@ -130,7 +130,7 @@ def do_expand_kernel(
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, )
offset_k[None, :] * input_d2_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride)
......
......@@ -136,6 +136,7 @@ def _lora_expand(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
......@@ -157,11 +158,19 @@ def _lora_expand(
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
......@@ -170,6 +179,8 @@ def _lora_expand(
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
......@@ -181,7 +192,6 @@ def _lora_expand(
inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
M = inputs.size(1)
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
......@@ -263,6 +273,7 @@ def _lora_expand_fake(
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
......
......@@ -17,6 +17,17 @@ class LoRAKernelMeta:
num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor
# The V1 architecture uses the traced torch.compile graphs to execute
# a forward pass. Things to note about this process,
# 1. The tracing infers all python scalar datatype objects into a constant
# value.
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
# is an experimental feature in pytorch)
# 3. The internals of torch.ops functions are not traced.
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
@staticmethod
def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "LoRAKernelMeta":
......@@ -47,17 +58,24 @@ class LoRAKernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32,
device=device)
no_lora_flag_cpu = torch.tensor([False],
dtype=torch.bool,
device='cpu')
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc)
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu)
def _reset(self):
    """Refill the per-step bookkeeping tensors with their idle values."""
    idle_values = (
        (self.active_lora_ids, -1),
        (self.num_tokens_per_lora, 0),
        (self.lora_token_start_loc, 0),
        (self.no_lora_flag_cpu, False),
    )
    # fill_ works in place, so the tensor objects themselves are kept.
    for buffer, idle_value in idle_values:
        buffer.fill_(idle_value)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
......@@ -70,6 +88,14 @@ class LoRAKernelMeta:
self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
......@@ -100,7 +126,7 @@ class LoRAKernelMeta:
def meta_args(
self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:
torch.Tensor, torch.Tensor]:
"""
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
......@@ -111,7 +137,11 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward
pass.
"""
return (self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora, self.lora_token_start_loc,
self.active_lora_ids)
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
)
......@@ -106,6 +106,7 @@ def _lora_shrink(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float,
) -> None:
"""
......@@ -126,8 +127,16 @@ def _lora_shrink(
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
......@@ -138,6 +147,8 @@ def _lora_shrink(
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
......@@ -146,7 +157,6 @@ def _lora_shrink(
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
M = inputs.size(0)
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
......@@ -218,6 +228,7 @@ def _lora_shrink_fake(
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
......
......@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
......@@ -79,12 +79,6 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
......@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
......@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_backend is None.
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
......
......@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
"""
grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
if guided_params.json:
grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern})
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice:
......
......@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
......@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
......@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
......@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
......@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids)
input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids))
......@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
......@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
......@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
......
......@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
return None
elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer)
elif reasoning_backend == "granite":
logger.warning(
"Granite reasoner not yet implemented for structured outputs")
return None
else:
# Raise a warning for unknown reasoning backend and return None
# We cannot raise an error here because some reasoning models
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.

    The reasoning section is delimited by ``start_token`` and
    ``end_token``; their token ids are stored so that token-id sequences
    can be inspected without decoding.
    """
    start_token_id: int
    end_token_id: int
    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Build a reasoner by encoding the delimiter strings.

        The delimiter text is read from the class defaults rather than
        repeated as literals, so both always agree. Only the first id of
        each encoding is kept.
        """

        def first_token_id(text: str) -> int:
            return tokenizer.encode(text, add_special_tokens=False)[0]

        return cls(start_token_id=first_token_id(cls.start_token),
                   end_token_id=first_token_id(cls.end_token))

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the end-of-reasoning token id has appeared."""
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens

        Returns an empty list when the end token id is absent or is the
        last element of ``input_ids``.
        """
        # One scan locates the first end token; previously up to three
        # scans were made (`in` plus two `index` calls).
        try:
            end_position = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        # Slicing past the end yields [], covering the "last element" case.
        return input_ids[end_position + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
    """Abstract interface for reasoning-aware token-id handling.

    Concrete subclasses say how to build themselves from a tokenizer,
    how to tell that the reasoning section has ended, and how to obtain
    the content token ids from a sequence.
    """

    # Declared as a classmethod (with abstractmethod innermost) so the
    # abstract signature matches its `cls` parameter and the way
    # subclasses implement it as an alternate constructor.
    @classmethod
    @abstractmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Create a reasoner instance from ``tokenizer``."""
        pass

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether ``input_ids`` shows that reasoning has ended."""
        pass

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        """Return the content token ids taken from ``input_ids``."""
        pass
......@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
......@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: Reasoner | None,
reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
......@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: Reasoner | None = None
reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
......@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor:
elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar()
any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment