Commit 31330101 authored by zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
...@@ -69,12 +69,12 @@ class CpuPlatform(Platform): ...@@ -69,12 +69,12 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
ipex_avaliable = find_spec("intel_extension_for_pytorch") is not None ipex_available = find_spec("intel_extension_for_pytorch") is not None
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 128 if ipex_avaliable else 16 cache_config.block_size = 128 if ipex_available else 16
if not ipex_avaliable and cache_config.block_size != 16: if not ipex_available and cache_config.block_size != 16:
raise RuntimeError( raise RuntimeError(
f"--block-size={cache_config.block_size} requires" f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch") " intel_extension_for_pytorch")
......
...@@ -46,15 +46,15 @@ class HpuPlatform(Platform): ...@@ -46,15 +46,15 @@ class HpuPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
parallel_config = vllm_config.parallel_config
if scheduler_config.is_multi_step: if scheduler_config.is_multi_step:
raise NotImplementedError( parallel_config.worker_cls = \
"Multi-step execution is not implemented for HPU") "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
if vllm_config.speculative_config is not None: if vllm_config.speculative_config is not None:
raise NotImplementedError( raise NotImplementedError(
"Speculative decoding is not implemented for HPU") "Speculative decoding is not implemented for HPU")
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import enum import enum
import platform import platform
import random import random
...@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union ...@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
else: else:
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
LoRARequest = None
PoolingParams = None
SamplingParams = None
FlexibleArgumentParser = None FlexibleArgumentParser = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -231,7 +237,7 @@ class Platform: ...@@ -231,7 +237,7 @@ class Platform:
parser: Optional[FlexibleArgumentParser] = None parser: Optional[FlexibleArgumentParser] = None
) -> None: ) -> None:
""" """
Do some pre-registeration or update action for the current platform. Do some pre-registration or update action for the current platform.
This function is called before global VllmConfig is initialized or cli This function is called before global VllmConfig is initialized or cli
arguments are parsed. It's used for out-of-tree platforms to register or arguments are parsed. It's used for out-of-tree platforms to register or
...@@ -386,6 +392,14 @@ class Platform: ...@@ -386,6 +392,14 @@ class Platform:
""" """
return False return False
@classmethod
def validate_request(
    cls,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
) -> None:
    """Raise if this request is unsupported on this platform.

    This implementation contains no checks, so it returns ``None`` for
    every ``prompt`` / ``params`` combination it is given.
    """
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else: else:
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
PoolingParams = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -116,6 +120,13 @@ class TpuPlatform(Platform): ...@@ -116,6 +120,13 @@ class TpuPlatform(Platform):
assert not vllm_config.speculative_config, ( assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend") "Speculative decoding is not yet supported for TPU backend")
if scheduler_config.is_multimodal_model and not \
scheduler_config.disable_chunked_mm_input:
logger.warning("TPU does not support running Multimodal models"\
" without setting `--disable_chunked_mm_input`. " \
"Forcing --disable_chunked_mm_input.")
scheduler_config.disable_chunked_mm_input = True
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.") logger.warning("Pin memory is not supported on TPU.")
...@@ -133,3 +144,18 @@ class TpuPlatform(Platform): ...@@ -133,3 +144,18 @@ class TpuPlatform(Platform):
def supports_v1(cls, model_config: ModelConfig) -> bool: def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on TPU is experimental # V1 support on TPU is experimental
return True return True
@classmethod
def validate_request(
    cls,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
) -> None:
    """Raise if this request is unsupported on this platform.

    Only ``SamplingParams`` are inspected; any other ``params`` object is
    accepted without checks.

    Raises:
        ValueError: if ``params.guided_decoding`` is set, or if
            ``params.sampling_type`` is ``SamplingType.RANDOM_SEED``.
    """
    # Nothing to validate for non-sampling (e.g. pooling) requests.
    if not isinstance(params, SamplingParams):
        return

    if params.guided_decoding is not None:
        raise ValueError("Structured output is not supported on "
                         f"{cls.device_name}.")

    if params.sampling_type == SamplingType.RANDOM_SEED:
        raise ValueError("Torch XLA does not support per-request seed.")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
if TYPE_CHECKING:
from vllm.config import ModelConfig
class PoolingParams( class PoolingParams(
msgspec.Struct, msgspec.Struct,
...@@ -12,14 +15,30 @@ class PoolingParams( ...@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder. """API parameters for pooling models. This is currently a placeholder.
Attributes: Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling. additional_data: Any additional data needed for pooling.
""" """
dimensions: Optional[int] = None
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
return PoolingParams(additional_data=self.additional_data) return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f'support matryoshka representation, '
f'changing output dimensions will lead to poor results.')
if self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"PoolingParams(" return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})") f"additional_metadata={self.additional_data})")
...@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser): ...@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser):
Args: Args:
model_output (str): Output of the model to be parsed. model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed. request (ChatCompletionRequest): Request being processed.
Returns: Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the tuple[Optional[str], Optional[str]]: Tuple pair containing the
......
...@@ -101,7 +101,7 @@ class RequestOutputKind(Enum): ...@@ -101,7 +101,7 @@ class RequestOutputKind(Enum):
CUMULATIVE = 0 CUMULATIVE = 0
# Return only deltas in each RequestOutput # Return only deltas in each RequestOutput
DELTA = 1 DELTA = 1
# Do not return intermediate RequestOuputs # Do not return intermediate RequestOutput
FINAL_ONLY = 2 FINAL_ONLY = 2
...@@ -385,9 +385,10 @@ class SamplingParams( ...@@ -385,9 +385,10 @@ class SamplingParams(
if not -2.0 <= self.frequency_penalty <= 2.0: if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2], got " raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.") f"{self.frequency_penalty}.")
if not 0.0 < self.repetition_penalty <= 2.0: if self.repetition_penalty <= 0.0:
raise ValueError("repetition_penalty must be in (0, 2], got " raise ValueError(
f"{self.repetition_penalty}.") "repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0: if self.temperature < 0.0:
raise ValueError( raise ValueError(
f"temperature must be non-negative, got {self.temperature}.") f"temperature must be non-negative, got {self.temperature}.")
......
...@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure): ...@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure):
e.g. class that has _field_ 'hex_value', c_uint could be formatted with e.g. class that has _field_ 'hex_value', c_uint could be formatted with
_fmt_ = {"hex_value" : "%08X"} _fmt_ = {"hex_value" : "%08X"}
to produce nicer output. to produce nicer output.
Default fomratting string for all fields can be set with key "<default>" like: Default formatting string for all fields can be set with key "<default>" like:
_fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz. _fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz.
If not set it's assumed to be just "%s" If not set it's assumed to be just "%s"
......
...@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], ...@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
def get_hf_image_processor_config( def get_hf_image_processor_config(
model: Union[str, Path], model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -721,7 +722,10 @@ def get_hf_image_processor_config( ...@@ -721,7 +722,10 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
if check_gguf_file(model): if check_gguf_file(model):
model = Path(model).parent model = Path(model).parent
return get_image_processor_config(model, revision=revision, **kwargs) return get_image_processor_config(model,
token=hf_token,
revision=revision,
**kwargs)
def get_hf_text_config(config: PretrainedConfig): def get_hf_text_config(config: PretrainedConfig):
......
...@@ -5,6 +5,7 @@ from typing import Optional, Union ...@@ -5,6 +5,7 @@ from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
import vllm.envs as envs
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
...@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig): ...@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
self.truncated_vocab_size = self.model.vocab_size if \ self.truncated_vocab_size = self.model.vocab_size if \
truncated_vocab_size is None else truncated_vocab_size truncated_vocab_size is None else truncated_vocab_size
if "architectures" not in kwargs: if not envs.VLLM_USE_V1:
kwargs["architectures"] = ["EAGLEModel"] kwargs["architectures"] = ["EAGLEModel"]
else:
kwargs["architectures"] = ["EagleLlamaForCausalLM"]
super().__init__(**kwargs) super().__init__(**kwargs)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
truncate_tool_call_ids) truncate_tool_call_ids, validate_request_params)
__all__ = [ __all__ = [
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids" "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids",
"validate_request_params"
] ]
...@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): ...@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"):
    """Reject request options that are not supported for Mistral tokenizers.

    Raises:
        ValueError: if ``request.skip_special_tokens`` is set (not ``None``)
            and evaluates to false.
    """
    skip_special_tokens = request.skip_special_tokens
    # Unset (None) or truthy values are accepted as-is.
    if skip_special_tokens is None or skip_special_tokens:
        return
    raise ValueError("skip_special_tokens=False is not supported "
                     "for Mistral tokenizers.")
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join( repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE, huggingface_hub.constants.HF_HUB_CACHE,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
from functools import cache from functools import cache
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
...@@ -51,6 +52,26 @@ def modelscope_list_repo_files( ...@@ -51,6 +52,26 @@ def modelscope_list_repo_files(
return files return files
def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]:
    """Best-effort parse of the file at ``path`` as a JSON object.

    Args:
        path: Location of the file to read.

    Returns:
        The decoded mapping when the file holds a JSON object. An empty
        dict is returned when the content is not valid JSON, cannot be
        decoded as text, or its top-level value is not an object, so that
        the result can always be used as a mapping.

    Raises:
        OSError: if the file cannot be opened.
    """
    with open(path) as f:
        try:
            parsed = json.load(f)
        except (ValueError, RecursionError):
            # json.JSONDecodeError and UnicodeDecodeError are both
            # ValueError subclasses; RecursionError covers absurdly nested
            # input. Either way the file is simply "not a JSON dict".
            return {}
    # Valid JSON is not necessarily an object (e.g. a list or a string);
    # returning it would violate the declared return type.
    return parsed if isinstance(parsed, dict) else {}
def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]:
    """Parse the file at ``path`` as whitespace-separated redirect pairs.

    Each line is expected to hold exactly two whitespace-separated fields,
    ``<model_name> <redirect_name>``. Lines with any other number of fields
    (including blank lines) are ignored. When a model name occurs more than
    once, the last occurrence wins.

    Args:
        path: Location of the file to read.

    Returns:
        Mapping from model name to redirect name; empty if no line matched.

    Raises:
        OSError: if the file cannot be opened.
    """
    parsed_dict: dict[str, str] = {}
    with open(path) as f:
        # Iterate the file lazily instead of materializing every line.
        for line in f:
            fields = line.split()
            # Check the shape explicitly rather than relying on a failed
            # tuple-unpack being swallowed by a broad `except`.
            if len(fields) != 2:
                continue
            model_name, redirect_name = fields
            parsed_dict[model_name] = redirect_name
    return parsed_dict
@cache @cache
def maybe_model_redirect(model: str) -> str: def maybe_model_redirect(model: str) -> str:
""" """
...@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str: ...@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str:
if not Path(model_redirect_path).exists(): if not Path(model_redirect_path).exists():
return model return model
with open(model_redirect_path) as f: redirect_dict = (_maybe_json_dict(model_redirect_path)
for line in f.readlines(): or _maybe_space_split_dict(model_redirect_path))
try: if (redirect_model := redirect_dict.get(model)):
model_name, redirect_name = line.split("\t") logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
if model == model_name: return redirect_model
redirect_name = redirect_name.strip()
logger.info("model redirect: [ %s ] -> [ %s ]", model,
redirect_name)
return redirect_name
except Exception:
pass
return model return model
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import argparse
import asyncio import asyncio
import concurrent import concurrent
import contextlib import contextlib
...@@ -25,6 +24,7 @@ import socket ...@@ -25,6 +24,7 @@ import socket
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import textwrap
import threading import threading
import time import time
import traceback import traceback
...@@ -32,6 +32,8 @@ import types ...@@ -32,6 +32,8 @@ import types
import uuid import uuid
import warnings import warnings
import weakref import weakref
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
ArgumentTypeError)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
...@@ -40,7 +42,7 @@ from dataclasses import dataclass, field ...@@ -40,7 +42,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union, cast, overload) Optional, Tuple, Type, TypeVar, Union, cast, overload)
from uuid import uuid4 from uuid import uuid4
import cachetools import cachetools
...@@ -53,6 +55,7 @@ import torch.types ...@@ -53,6 +55,7 @@ import torch.types
import yaml import yaml
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from packaging import version
from packaging.version import Version from packaging.version import Version
from torch.library import Library from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never from typing_extensions import Never, ParamSpec, TypeIs, assert_never
...@@ -1209,7 +1212,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]: ...@@ -1209,7 +1212,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return wrapper return wrapper
class StoreBoolean(argparse.Action): class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true": if values.lower() == "true":
...@@ -1221,15 +1224,28 @@ class StoreBoolean(argparse.Action): ...@@ -1221,15 +1224,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'.") "Expected 'true' or 'false'.")
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
    """SortedHelpFormatter that sorts arguments by their option strings."""

    # Compiled once for the class rather than on every `_split_lines` call.
    # Both patterns also consume the whitespace that follows the newline(s).
    _SINGLE_NEWLINE = re.compile(r"(?<!\n)\n(?!\n)\s*")
    _MULTIPLE_NEWLINES = re.compile(r"\n{2,}\s*")

    def _split_lines(self, text, width):
        """
        1. Sentences split across lines have their single newlines removed.
        2. Paragraphs and explicit newlines are split into separate lines.
        3. Each line is wrapped to the specified width (width of terminal).

        Returns a flat list of wrapped lines; an empty paragraph contributes
        no lines.
        """
        text = self._SINGLE_NEWLINE.sub(' ', text)
        paragraphs = self._MULTIPLE_NEWLINES.split(text)
        # Flatten in a single pass; `sum(lists, [])` copies the accumulated
        # list on every addition and is quadratic in the number of lines.
        return [
            wrapped_line for paragraph in paragraphs
            for wrapped_line in textwrap.wrap(paragraph, width)
        ]

    def add_arguments(self, actions):
        """Add ``actions`` to the help output ordered by option strings."""
        actions = sorted(actions, key=lambda x: x.option_strings)
        super().add_arguments(actions)
class FlexibleArgumentParser(argparse.ArgumentParser): class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names.""" """ArgumentParser that allows both underscore and dash in names."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -1280,11 +1296,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1280,11 +1296,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value = int(value) value = int(value)
except ValueError: except ValueError:
msg = "Port must be an integer" msg = "Port must be an integer"
raise argparse.ArgumentTypeError(msg) from None raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535): if not (1024 <= value <= 65535):
raise argparse.ArgumentTypeError( raise ArgumentTypeError("Port must be between 1024 and 65535")
"Port must be between 1024 and 65535")
return value return value
...@@ -2060,12 +2075,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa ...@@ -2060,12 +2075,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op( def direct_register_custom_op(
op_name: str, op_name: str,
op_func: Callable, op_func: Callable,
mutates_args: list[str], mutates_args: list[str],
fake_impl: Optional[Callable] = None, fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None, target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA", dispatch_key: str = "CUDA",
tags: Tuple[torch.Tag, ...] = (),
): ):
""" """
`torch.library.custom_op` can have significant overhead because it `torch.library.custom_op` can have significant overhead because it
...@@ -2104,7 +2120,7 @@ def direct_register_custom_op( ...@@ -2104,7 +2120,7 @@ def direct_register_custom_op(
import torch._custom_op.impl import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str) my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None: if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl) my_lib._register_fake(op_name, fake_impl)
...@@ -2689,3 +2705,20 @@ def sha256(input) -> int: ...@@ -2689,3 +2705,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(), return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big") byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        installed = version.parse(str(torch.__version__))
        required = version.parse(target)
        return installed >= required
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        installed_from_metadata = Version(importlib.metadata.version('torch'))
        return installed_from_metadata >= Version(target)
...@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops ...@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
is_quantized_kv_cache) is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches( ...@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int, attn_chunk_size: int,
query_start_loc_np: np.ndarray, query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray, seq_lens_np: np.ndarray,
block_table: torch.tensor, block_table: torch.Tensor,
page_size: int = 0, page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0] actual_batch_size = seq_lens_np.shape[0]
...@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches( ...@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
np.arange(pages_per_local_batch, dtype=np.int32), np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch)) \ (virtual_batches, pages_per_local_batch)) \
+ np.expand_dims(block_starts, axis=1) + np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten() block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch) local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\ block_table_local = block_table[batch_indices, block_indices]\
......
...@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention( ...@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention(
return spda_o @ W_O return spda_o @ W_O
NOTE: in the actual code, NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head `kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head `q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O `out_proj` is W_O
...@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops ...@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
......
...@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401 ...@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType) AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
...@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None, blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
use_irope: bool = False,
) -> None: ) -> None:
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does " raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.") "not support block-sparse attention.")
......
...@@ -67,11 +67,11 @@ class BlockPool: ...@@ -67,11 +67,11 @@ class BlockPool:
Returns: Returns:
The cached block if it exists, or None. The cached block if it exists, or None.
""" """
if block_hash in self.cached_block_hash_to_block: cached_blocks = self.cached_block_hash_to_block.get(block_hash)
first_block_id = list( if not cached_blocks:
self.cached_block_hash_to_block[block_hash].keys())[0] return None
return self.cached_block_hash_to_block[block_hash][first_block_id] first_block_id = next(iter(cached_blocks))
return None return cached_blocks[first_block_id]
def cache_full_blocks( def cache_full_blocks(
self, self,
......
...@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal( ...@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(), _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1]) key=lambda item: item[1])
if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
> scheduler_config.max_num_batched_tokens):
raise ValueError(
"Chunked MM input disabled but max_tokens_per_mm_item "
f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
f" ({scheduler_config.max_num_batched_tokens}). Please increase "
"max_num_batched_tokens.")
encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item) max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size, encoder_cache_size = max(scheduler_config.encoder_cache_size,
......
...@@ -126,44 +126,46 @@ class KVCacheManager: ...@@ -126,44 +126,46 @@ class KVCacheManager:
self.req_to_block_hashes[request.request_id] = block_hashes self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1 self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None: # When the request requires prompt logprobs, we skip prefix caching.
if len(block_hashes) * self.block_size == request.num_tokens: if request.sampling_params.prompt_logprobs is not None:
# When prompt length is divisible by the block size and all return [], 0
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
computed_blocks = ( if len(block_hashes) * self.block_size == request.num_tokens:
self.specialized_manager.find_longest_cache_hit(block_hashes)) # When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
if last_block_hash is not None: computed_blocks = (
# Add back the last block hash if it was removed. self.specialized_manager.find_longest_cache_hit(block_hashes))
block_hashes.append(last_block_hash) self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
self.prefix_cache_stats.queries += len(block_hashes) if last_block_hash is not None:
self.prefix_cache_stats.hits += len(computed_blocks) # Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)
# NOTE(woosuk): Since incomplete blocks are not eligible for # NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of # sharing, `num_computed_tokens` is always a multiple of
# `block_size`. # `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens return computed_blocks, num_computed_tokens
else:
# Skip cache hits for prompt logprobs
return [], 0
def allocate_slots( def allocate_slots(
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]: ) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append. """Add slots for a request with new tokens to append.
...@@ -173,6 +175,9 @@ class KVCacheManager: ...@@ -173,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed. not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the new_computed_blocks: A list of new computed blocks just hitting the
prefix caching. prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout: Blocks layout:
----------------------------------------------------------------------- -----------------------------------------------------------------------
...@@ -210,8 +215,9 @@ class KVCacheManager: ...@@ -210,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size) len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens, num_required_blocks = cdiv(
self.block_size) num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) - num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks)) len(new_computed_blocks))
...@@ -245,8 +251,11 @@ class KVCacheManager: ...@@ -245,8 +251,11 @@ class KVCacheManager:
else: else:
# Get new blocks from the free block pool considering # Get new blocks from the free block pool considering
# preallocated blocks. # preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min( num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks, num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(), self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request. # Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape # This is especially because the block table has the shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment