Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GGUF utility functions."""
from functools import cache
from os import PathLike
from pathlib import Path
import gguf
import regex as re
from gguf.constants import Keys, VisionProjectorType
from gguf.quants import GGMLQuantizationType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
......@@ -15,6 +19,73 @@ from .repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
try:
with model.open("rb") as f:
header = f.read(4)
return header == b"GGUF"
except Exception as e:
logger.debug("Error reading file %s: %s", model, e)
return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
pattern = r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*:[A-Za-z0-9_+-]+$"
model = str(model)
if re.fullmatch(pattern, model):
_, quant_type = model.rsplit(":", 1)
return is_valid_gguf_quant_type(quant_type)
return False
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
"- It should be in repo_id:quant_type format.\n"
f"- Valid GGMLQuantizationType values: {GGMLQuantizationType._member_names_}",
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def detect_gguf_multimodal(model: str) -> Path | None:
"""Check if GGUF model has multimodal projector file.
......
......@@ -18,7 +18,8 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
......
......@@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin):
attention_mask = input_ids.ne(self.pad_id)
text_inputs["attention_mask"] = attention_mask
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors = kwargs.pop("return_tensors", None)
......
......@@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"]
try:
from runai_model_streamer import list_safetensors as runai_list_safetensors
from runai_model_streamer import pull_files as runai_pull_files
except (ImportError, OSError):
# see https://github.com/run-ai/runai-model-streamer/issues/26
# OSError will be raised on arm64 platform
except ImportError:
runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment]
runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors")
......
......@@ -9,8 +9,6 @@ from os import PathLike
from pathlib import Path
from typing import Any
from gguf import GGMLQuantizationType
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -29,76 +27,6 @@ def is_cloud_storage(model_or_path: str) -> bool:
return is_s3(model_or_path) or is_gcs(model_or_path)
@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
try:
with model.open("rb") as f:
header = f.read(4)
return header == b"GGUF"
except Exception as e:
logger.debug("Error reading file %s: %s", model, e)
return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
model = str(model)
return (
(not is_cloud_storage(model))
and (not model.startswith(("http://", "https://")))
and ("/" in model and ":" in model)
and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
)
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
"Wrong GGUF model or invalid GGUF quant type: %s.\n"
"- It should be in repo_id:quant_type format.\n"
"- Valid GGMLQuantizationType values: %s",
model,
GGMLQuantizationType._member_names_,
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def modelscope_list_repo_files(
repo_id: str,
revision: str | None = None,
......
......@@ -244,9 +244,8 @@ class FlexibleArgumentParser(ArgumentParser):
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
elif arg.startswith("-O") and arg != "-O":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<optimization_level> here
optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
processed_args += ["--optimization-level", optimization_level]
......@@ -257,17 +256,6 @@ class FlexibleArgumentParser(ArgumentParser):
):
# Convert -O <n> to --optimization-level <n>
processed_args.append("--optimization-level")
elif arg.startswith("-O."):
# Handle -O.* dotted syntax - ALL dotted syntax is deprecated
logger.warning_once(
"The -O.* dotted syntax for --compilation-config is "
"deprecated and will be removed in v0.13.0 or v1.0.0"
", whichever is earlier. Please use -cc.* instead. "
"Example: -cc.backend=eager instead of "
"-O.backend=eager."
)
converted_arg = arg.replace("-O", "-cc", 1)
processed_args.append(converted_arg)
else:
processed_args.append(arg)
......
......@@ -481,8 +481,25 @@ def should_use_deepgemm_for_fp8_linear(
)
def should_use_deepgemm_for_fp8_linear_for_nk(
output_dtype: torch.dtype,
shape0: int,
shape1: int,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and shape0 % 128 == 0
and shape1 % 128 == 0
)
__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
......@@ -494,6 +511,7 @@ __all__ = [
"is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear",
"should_use_deepgemm_for_fp8_linear_for_nk",
"get_col_major_tma_aligned_tensor",
"get_mk_alignment_for_contiguous_layout",
]
......@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> bool | None:
"""
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
Return `None` if --attention-config.use_trtllm_attention is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.attention_config.use_trtllm_attention
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
......@@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
# Environment variable is set to 0 - respect it
# CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False
......@@ -324,7 +319,7 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"but --attention-config.use_trtllm_attention is set to 1"
)
return False
......@@ -333,7 +328,8 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
)
return False
......@@ -354,7 +350,7 @@ def use_trtllm_attention(
return True
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
# CLI argument not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"
......@@ -367,8 +363,10 @@ def use_trtllm_attention(
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
# Environment variable is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
# CLI argument is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True
......@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
......@@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]
......@@ -11,6 +11,17 @@ from typing import Any
import cbor2
try:
# It is important that this remains an optional dependency.
# It would not be allowed in environments with strict security controls,
# so it's best not to have it installed when not in use.
import xxhash as _xxhash
if not hasattr(_xxhash, "xxh3_128_digest"):
_xxhash = None
except ImportError: # pragma: no cover
_xxhash = None
def sha256(input: Any) -> bytes:
"""Hash any picklable Python object using SHA-256.
......@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
return hashlib.sha256(input_bytes).digest()
def _xxhash_digest(input_bytes: bytes) -> bytes:
if _xxhash is None:
raise ModuleNotFoundError(
"xxhash is required for the 'xxhash' prefix caching hash algorithms. "
"Install it via `pip install xxhash`."
)
return _xxhash.xxh3_128_digest(input_bytes)
def xxhash(input: Any) -> bytes:
"""Hash picklable objects using xxHash."""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return _xxhash_digest(input_bytes)
def xxhash_cbor(input: Any) -> bytes:
"""Hash objects serialized with CBOR using xxHash."""
input_bytes = cbor2.dumps(input, canonical=True)
return _xxhash_digest(input_bytes)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if the function is not found.
......@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256
if hash_fn_name == "sha256_cbor":
return sha256_cbor
if hash_fn_name == "xxhash":
return xxhash
if hash_fn_name == "xxhash_cbor":
return xxhash_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import torch
import torch.cuda.nvtx as nvtx
def print_tensor(tensor_obj, prefix, tensor_list=None):
"""Descends iterators that contains Tensors and prints the Tensor.
Recursive function that descends iterator type arguments until
it finds a Tensor object.
"""
if tensor_list is None:
tensor_list = []
if isinstance(tensor_obj, (list, tuple)):
for ten in tensor_obj:
tensor_list = print_tensor(ten, prefix, tensor_list)
elif isinstance(tensor_obj, torch.Tensor):
tensor_dims = list(tensor_obj.size())
tensor_list.append(tensor_dims)
return tensor_list
def process_layer_params(module_obj):
"""Extract the static parameters from LLM and VLM relevant layer types"""
param_info = {}
# Extract parameters for layers commonly used in LLMs and VLMs
if isinstance(module_obj, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
conv_params = {}
conv_params["in_chan"] = module_obj.in_channels
conv_params["out_chan"] = module_obj.out_channels
conv_params["filter_dim"] = module_obj.kernel_size
conv_params["stride"] = module_obj.stride
conv_params["padding"] = module_obj.padding
conv_params["dilation"] = module_obj.dilation
conv_params["transposed"] = module_obj.transposed
conv_params["output_padding"] = module_obj.output_padding
conv_params["groups"] = module_obj.groups
conv_params["padding_mode"] = module_obj.padding_mode
param_info = conv_params
elif isinstance(
module_obj,
(
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
),
):
convtranspose_params = {}
convtranspose_params["in_chan"] = module_obj.in_channels
convtranspose_params["out_chan"] = module_obj.out_channels
convtranspose_params["filter_dim"] = module_obj.kernel_size
convtranspose_params["stride"] = module_obj.stride
convtranspose_params["padding"] = module_obj.padding
convtranspose_params["dilation"] = module_obj.dilation
convtranspose_params["transposed"] = module_obj.transposed
convtranspose_params["output_padding"] = module_obj.output_padding
convtranspose_params["groups"] = module_obj.groups
convtranspose_params["padding_mode"] = module_obj.padding_mode
param_info = convtranspose_params
elif isinstance(
module_obj, (torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)
):
def _handle_int_or_tuple(parameter):
if isinstance(parameter, tuple):
return list(parameter)
elif isinstance(parameter, int):
return [parameter, parameter]
pooling_params = {}
pooling_params["filter_dim"] = _handle_int_or_tuple(module_obj.kernel_size)
pooling_params["stride"] = _handle_int_or_tuple(module_obj.stride)
pooling_params["padding"] = _handle_int_or_tuple(module_obj.padding)
pooling_params["dilation"] = _handle_int_or_tuple(module_obj.dilation)
param_info = pooling_params
elif isinstance(
module_obj, (torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d)
):
pooling_params = {}
pooling_params["filter_dim"] = [
module_obj.kernel_size,
module_obj.kernel_size,
]
pooling_params["stride"] = [module_obj.stride, module_obj.stride]
pooling_params["padding"] = [module_obj.padding, module_obj.padding]
pooling_params["ceil_mode"] = module_obj.ceil_mode
pooling_params["count_include_pad"] = module_obj.count_include_pad
param_info = pooling_params
elif isinstance(
module_obj,
(
torch.nn.AdaptiveAvgPool1d,
torch.nn.AdaptiveAvgPool2d,
torch.nn.AdaptiveAvgPool3d,
),
):
pooling_params = {}
pooling_params["output_size"] = [
module_obj.output_size,
module_obj.output_size,
]
param_info = pooling_params
elif isinstance(module_obj, torch.nn.Linear):
param_info["in_features"] = module_obj.in_features
param_info["out_features"] = module_obj.out_features
elif isinstance(
module_obj,
(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d),
):
param_info["num_features"] = module_obj.num_features
param_info["epsilon"] = module_obj.eps
param_info["momentum"] = module_obj.momentum
elif isinstance(module_obj, torch.nn.ReLU):
param_info["in_place"] = module_obj.inplace
elif isinstance(module_obj, torch.nn.Dropout):
param_info["p"] = module_obj.p
param_info["in_place"] = module_obj.inplace
elif isinstance(module_obj, torch.nn.Embedding):
param_info["num_embeddings"] = module_obj.num_embeddings
param_info["embedding_dim"] = module_obj.embedding_dim
elif isinstance(
module_obj,
(
torch.nn.Upsample,
torch.nn.UpsamplingNearest2d,
torch.nn.UpsamplingBilinear2d,
),
):
param_info["scale_factor"] = module_obj.scale_factor
return param_info
def construct_marker_dict_and_push(
module_name, module_obj, in_tensor, kwargs=None, out_tensor=None
):
marker_dict = {}
marker_dict["Module"] = module_name
## Get trainable parameters like weights and bias
module_params = module_obj.named_parameters(recurse=False)
for idx, (param_name, param_obj) in enumerate(module_params):
if idx == 0:
marker_dict["TrainableParams"] = {}
marker_dict["TrainableParams"][param_name] = list(param_obj.size())
in_tensor_list = print_tensor(in_tensor, "Input")
if in_tensor_list:
marker_dict["Inputs"] = in_tensor_list
out_tensor_list = print_tensor(out_tensor, "Output")
if out_tensor_list:
marker_dict["Outputs"] = out_tensor_list
## Get Kwargs like input_ids and positions for the top module
if kwargs:
for key, value in kwargs.items():
if isinstance(value, (torch.Tensor, list, tuple)):
tensor_list = print_tensor(value, key)
if tensor_list:
marker_dict[key] = tensor_list
param_info = process_layer_params(module_obj)
if param_info:
marker_dict["StaticParams"] = param_info
nvtx.range_push("{}".format(marker_dict))
class ResultHolder:
"""Holder for storing results from within a context manager."""
result = None
@contextmanager
def layerwise_nvtx_marker_context(module_name, module_obj, in_tensor=None, kwargs=None):
"""Context manager for NVTX markers that automatically pushes on enter
and pops on exit.
Example:
with nvtx_marker_context("Module:MyModule", module, in_tensor=args,
kwargs=kwargs) as ctx:
ctx.result = module(*args, **kwargs)
return ctx.result
"""
holder = ResultHolder()
# Push input marker
construct_marker_dict_and_push(
module_name,
module_obj,
in_tensor=in_tensor,
kwargs=kwargs,
)
try:
yield holder
finally:
# Pop input marker
nvtx.range_pop()
# Push and pop output marker
output_name = module_name.replace("(input)", "(output)")
construct_marker_dict_and_push(
output_name,
module_obj,
in_tensor=None,
kwargs=None,
out_tensor=holder.result,
)
nvtx.range_pop()
class PytHooks:
"""This module contains all the code needed to enable forward hooks
in a pytorch network.
To register the hooks for a given network, the user needs to instantiate
a PytHook object. Then call the register_hooks method.
Example:
my_hook = PytHook()
my_hook.register_hooks(my_network_model)
"""
def __init__(self):
"""Initialize module variables."""
super().__init__()
self.module_to_name_map = {}
def _process_layer_params(self, module_obj):
return process_layer_params(module_obj)
def module_fwd_hook(self, module_obj, in_tensor, out_tensor):
"""Callback function that ends the NVTX marker.
Records the module name and tensor information.
Called after the module executes the forward method.
"""
nvtx.range_pop()
module_name = self.module_to_name_map.get(module_obj, "unknown")
construct_marker_dict_and_push(
module_name, module_obj, in_tensor=None, kwargs=None, out_tensor=out_tensor
)
nvtx.range_pop()
return
def module_fwd_pre_hook(self, module_obj, in_tensor, kwargs):
"""Creates an NVTX marker with the module name in it.
This function is called before the module executes.
"""
module_name = self.module_to_name_map.get(module_obj, "unknown")
construct_marker_dict_and_push(
module_name, module_obj, in_tensor=in_tensor, kwargs=kwargs, out_tensor=None
)
return
def register_hooks(self, network_model, module_prefix="top"):
"""User level function that activates all the hooks.
The user needs to call this method from the network source code.
The code descends all the modules in the network and registers their
respective hooks.
"""
# Module types to skip (simple operations that don't need detailed profiling)
skip_types = (
torch.nn.Identity,
torch.nn.Dropout,
torch.nn.Dropout1d,
torch.nn.Dropout2d,
torch.nn.Dropout3d,
)
for name, module in network_model.named_modules(prefix=module_prefix):
# Skip certain module types to reduce profiling overhead
if isinstance(module, skip_types):
continue
module.register_forward_pre_hook(self.module_fwd_pre_hook, with_kwargs=True)
module.register_forward_hook(self.module_fwd_hook)
if module not in self.module_to_name_map:
self.module_to_name_map[module] = name
else:
raise ValueError("Module instance {} is not unique ".format(module))
return
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import math
import sys
from dataclasses import dataclass
from typing import Literal
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
import torch
from typing_extensions import assert_never
from vllm import PoolingRequestOutput
if TYPE_CHECKING:
from vllm import PoolingRequestOutput
else:
PoolingRequestOutput = Any
sys_byteorder = sys.byteorder
......@@ -26,6 +31,14 @@ EMBED_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.float8_e5m2,
}
EMBED_DTYPE_TO_N_BYTES = {
"float32": 4,
"float16": 2,
"bfloat16": 2,
"fp8_e4m3": 1,
"fp8_e5m2": 1,
}
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
"float32": torch.float32,
......@@ -49,7 +62,16 @@ ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
def tensor2base64(x: torch.Tensor) -> str:
with io.BytesIO() as buf:
torch.save(x, buf)
buf.seek(0)
binary_data = buf.read()
return base64.b64encode(binary_data).decode("utf-8")
def tensor2binary(
......@@ -104,7 +126,7 @@ def encode_pooling_output(
elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes":
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format)
......@@ -119,6 +141,29 @@ class MetadataItem:
shape: tuple[int, ...]
def build_metadata_items(
embed_dtype: EmbedDType,
endianness: Endianness,
shape: tuple[int, ...],
n_request: int,
):
n_bytes = EMBED_DTYPE_TO_N_BYTES[embed_dtype]
size = math.prod(shape)
items = [
MetadataItem(
index=i,
embed_dtype=embed_dtype,
endianness=endianness,
start=i * size * n_bytes,
end=(i + 1) * size * n_bytes,
shape=shape,
)
for i in range(n_request)
]
return items
def encode_pooling_bytes(
pooling_outputs: list[PoolingRequestOutput],
embed_dtype: EmbedDType,
......
......@@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def decorate_logs(process_name: str | None = None) -> None:
"""Decorate stdout/stderr with process name and PID prefix."""
# Respect VLLM_CONFIGURE_LOGGING environment variable
if not envs.VLLM_CONFIGURE_LOGGING:
return
if process_name is None:
process_name = get_mp_context().current_process().name
......
......@@ -28,6 +28,7 @@ else:
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
......
......@@ -8,7 +8,6 @@ from typing import ClassVar
import numpy as np
import torch
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
......@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
......@@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self.max_num_splits = (
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
......@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)
def supports_quant_query_input(self) -> bool:
return True
self.supports_quant_query_input = True
def forward(
self,
......
......@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
......@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
flashinfer_disable_q_quantization,
use_trtllm_attention,
)
from vllm.utils.math_utils import cdiv
......@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
supports_trtllm_attention,
)
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
# Respect explicit disable flag (e.g.,
# --attention-config.use_trtllm_attention=0)
if force_use_trtllm_attention() is False:
return False
......@@ -482,9 +482,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1
self.num_qo_heads = (
self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
* self.dcp_world_size
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
)
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
......@@ -501,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_dtype = self.kv_cache_spec.dtype
# Use model dtype as q dtype when TRTLLM attn is not supported, or
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
# --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if can_use_trtllm and not flashinfer_disable_q_quantization():
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
):
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
......@@ -1036,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config()
self.supports_quant_query_input = (
self.support_trtllm_attn
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
......@@ -1047,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)
def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False
return self.support_trtllm_attn
# FlashInfer requires attention sinks to be float32
def process_weights_after_loading(self, act_dtype: torch.dtype):
if self.sinks is not None and self.sinks.dtype != torch.float32:
......
......@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
and_masks,
create_block_mask,
flex_attention,
or_masks,
)
from vllm.attention.backends.abstract import (
......@@ -31,6 +32,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
......@@ -41,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
torch._dynamo.config.recompile_limit = 16
create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead"
)
......@@ -90,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
"""FlexAttention supports both decoder and encoder-only attention."""
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@classmethod
def supports_mm_prefix(cls) -> bool:
"""FlexAttention supports full attention for image tokens."""
return True
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
......@@ -315,6 +323,7 @@ class FlexAttentionMetadata:
kv_block_size: int = 16
transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@cached_property
def logical_block_ids(self):
......@@ -442,6 +451,45 @@ class FlexAttentionMetadata:
return final_mask_mod if self.causal else sliding_window_mask_mod
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
"""Creates the prefix LM mask_mod function for FlexAttention."""
assert self.doc_ids is not None
request_lookup = self.doc_ids
def prefix_lm_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
cu_q_idx: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
):
mask = torch.zeros_like(q_idx, dtype=torch.bool)
for req, doc_range_lst in (self.mm_prefix_range or {}).items():
req_mask = request_lookup[cu_q_idx] == req
for start, end in doc_range_lst:
doc_mask_q = (q_idx >= start) & (q_idx <= end)
doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
return mask
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx, logical_kv_idx) = (
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
)
return torch.where(
is_valid,
prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
......@@ -455,6 +503,10 @@ class FlexAttentionMetadata:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
if self.mm_prefix_range:
# Add prefix LM mask for vision-language prefix LM attention
prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
return mask_mod
def get_transformed_score_mod(self) -> _score_mod_signature | None:
......@@ -708,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
sliding_window: int | None
alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
def __init__(
self,
......@@ -809,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
needs_rebuild_block_mask = False
if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build:
# update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
self.mm_prefix_range = attn_metadata.mm_prefix_range
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if needs_rebuild_block_mask:
if attn_metadata.direct_build and attn_metadata.causal:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
......@@ -927,7 +990,18 @@ def get_kernel_options(
if torch.cuda.is_available():
device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin
# ROCm doesn't expose shared_memory_per_block_optin attribute
# AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup
if hasattr(device_props, "shared_memory_per_block_optin"):
max_shared_memory = device_props.shared_memory_per_block_optin
elif current_platform.is_rocm():
# ROCm fallback: use 64KB
max_shared_memory = 65536
else:
raise RuntimeError(
"Unable to determine shared memory size on this hardware."
)
if max_shared_memory < 144 * 1024:
block_m_candidate = ensure_divisible(
max(1, block_m_candidate // 2), block_m
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment