"vscode:/vscode.git/clone" did not exist on "65334ef3b9e4fd32ebc5c4e512debc25d5025488"
Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -2,10 +2,14 @@ ...@@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GGUF utility functions.""" """GGUF utility functions."""
from functools import cache
from os import PathLike
from pathlib import Path from pathlib import Path
import gguf import gguf
import regex as re
from gguf.constants import Keys, VisionProjectorType from gguf.constants import Keys, VisionProjectorType
from gguf.quants import GGMLQuantizationType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -15,6 +19,73 @@ from .repo_utils import list_filtered_repo_files ...@@ -15,6 +19,73 @@ from .repo_utils import list_filtered_repo_files
logger = init_logger(__name__) logger = init_logger(__name__)
@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
try:
with model.open("rb") as f:
header = f.read(4)
return header == b"GGUF"
except Exception as e:
logger.debug("Error reading file %s: %s", model, e)
return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
pattern = r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*:[A-Za-z0-9_+-]+$"
model = str(model)
if re.fullmatch(pattern, model):
_, quant_type = model.rsplit(":", 1)
return is_valid_gguf_quant_type(quant_type)
return False
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
"- It should be in repo_id:quant_type format.\n"
f"- Valid GGMLQuantizationType values: {GGMLQuantizationType._member_names_}",
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def detect_gguf_multimodal(model: str) -> Path | None: def detect_gguf_multimodal(model: str) -> Path | None:
"""Check if GGUF model has multimodal projector file. """Check if GGUF model has multimodal projector file.
......
...@@ -18,7 +18,8 @@ from transformers.processing_utils import ProcessorMixin ...@@ -18,7 +18,8 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin): ...@@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin):
attention_mask = input_ids.ne(self.pad_id) attention_mask = input_ids.ne(self.pad_id)
text_inputs["attention_mask"] = attention_mask text_inputs["attention_mask"] = attention_mask
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)] text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]] # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors = kwargs.pop("return_tensors", None) return_tensors = kwargs.pop("return_tensors", None)
......
...@@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"] ...@@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"]
try: try:
from runai_model_streamer import list_safetensors as runai_list_safetensors from runai_model_streamer import list_safetensors as runai_list_safetensors
from runai_model_streamer import pull_files as runai_pull_files from runai_model_streamer import pull_files as runai_pull_files
except (ImportError, OSError): except ImportError:
# see https://github.com/run-ai/runai-model-streamer/issues/26
# OSError will be raised on arm64 platform
runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment]
runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors") runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors")
......
...@@ -9,8 +9,6 @@ from os import PathLike ...@@ -9,8 +9,6 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from gguf import GGMLQuantizationType
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -29,76 +27,6 @@ def is_cloud_storage(model_or_path: str) -> bool: ...@@ -29,76 +27,6 @@ def is_cloud_storage(model_or_path: str) -> bool:
return is_s3(model_or_path) or is_gcs(model_or_path) return is_s3(model_or_path) or is_gcs(model_or_path)
@cache
def check_gguf_file(model: str | PathLike) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
if not model.is_file():
return False
elif model.suffix == ".gguf":
return True
try:
with model.open("rb") as f:
header = f.read(4)
return header == b"GGUF"
except Exception as e:
logger.debug("Error reading file %s: %s", model, e)
return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
model = str(model)
return (
(not is_cloud_storage(model))
and (not model.startswith(("http://", "https://")))
and ("/" in model and ":" in model)
and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
)
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
"Wrong GGUF model or invalid GGUF quant type: %s.\n"
"- It should be in repo_id:quant_type format.\n"
"- Valid GGMLQuantizationType values: %s",
model,
GGMLQuantizationType._member_names_,
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def modelscope_list_repo_files( def modelscope_list_repo_files(
repo_id: str, repo_id: str,
revision: str | None = None, revision: str | None = None,
......
...@@ -244,9 +244,8 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -244,9 +244,8 @@ class FlexibleArgumentParser(ArgumentParser):
else: else:
key = pattern.sub(repl, arg, count=1) key = pattern.sub(repl, arg, count=1)
processed_args.append(key) processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": elif arg.startswith("-O") and arg != "-O":
# allow -O flag to be used without space, e.g. -O3 or -Odecode # allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<optimization_level> here # also handle -O=<optimization_level> here
optimization_level = arg[3:] if arg[2] == "=" else arg[2:] optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
processed_args += ["--optimization-level", optimization_level] processed_args += ["--optimization-level", optimization_level]
...@@ -257,17 +256,6 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -257,17 +256,6 @@ class FlexibleArgumentParser(ArgumentParser):
): ):
# Convert -O <n> to --optimization-level <n> # Convert -O <n> to --optimization-level <n>
processed_args.append("--optimization-level") processed_args.append("--optimization-level")
elif arg.startswith("-O."):
# Handle -O.* dotted syntax - ALL dotted syntax is deprecated
logger.warning_once(
"The -O.* dotted syntax for --compilation-config is "
"deprecated and will be removed in v0.13.0 or v1.0.0"
", whichever is earlier. Please use -cc.* instead. "
"Example: -cc.backend=eager instead of "
"-O.backend=eager."
)
converted_arg = arg.replace("-O", "-cc", 1)
processed_args.append(converted_arg)
else: else:
processed_args.append(arg) processed_args.append(arg)
......
...@@ -481,8 +481,25 @@ def should_use_deepgemm_for_fp8_linear( ...@@ -481,8 +481,25 @@ def should_use_deepgemm_for_fp8_linear(
) )
def should_use_deepgemm_for_fp8_linear_for_nk(
output_dtype: torch.dtype,
shape0: int,
shape1: int,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and shape0 % 128 == 0
and shape1 % 128 == 0
)
__all__ = [ __all__ = [
"calc_diff", "calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt", "fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous", "m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked", "fp8_m_grouped_gemm_nt_masked",
...@@ -494,6 +511,7 @@ __all__ = [ ...@@ -494,6 +511,7 @@ __all__ = [
"is_deep_gemm_supported", "is_deep_gemm_supported",
"get_num_sms", "get_num_sms",
"should_use_deepgemm_for_fp8_linear", "should_use_deepgemm_for_fp8_linear",
"should_use_deepgemm_for_fp8_linear_for_nk",
"get_col_major_tma_aligned_tensor", "get_col_major_tma_aligned_tensor",
"get_mk_alignment_for_contiguous_layout", "get_mk_alignment_for_contiguous_layout",
] ]
...@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool: ...@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return current_platform.is_device_capability(100) and has_nvidia_artifactory() return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> bool | None: def force_use_trtllm_attention() -> bool | None:
""" """
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set, Return `None` if --attention-config.use_trtllm_attention is not set,
return `True` if TRTLLM attention is forced to be used, return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used. return `False` if TRTLLM attention is forced to be not used.
""" """
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.attention_config.use_trtllm_attention
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
...@@ -307,7 +302,7 @@ def use_trtllm_attention( ...@@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used.""" """Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention() force_use_trtllm = force_use_trtllm_attention()
# Environment variable is set to 0 - respect it # CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm: if force_use_trtllm is not None and not force_use_trtllm:
return False return False
...@@ -324,7 +319,7 @@ def use_trtllm_attention( ...@@ -324,7 +319,7 @@ def use_trtllm_attention(
if force_use_trtllm: if force_use_trtllm:
logger.warning_once( logger.warning_once(
"TRTLLM attention is not supported on this platform, " "TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1" "but --attention-config.use_trtllm_attention is set to 1"
) )
return False return False
...@@ -333,7 +328,8 @@ def use_trtllm_attention( ...@@ -333,7 +328,8 @@ def use_trtllm_attention(
if force_use_trtllm: if force_use_trtllm:
logger.warning_once( logger.warning_once(
"TRTLLM attention is not supported for this combination of " "TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1" "query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
) )
return False return False
...@@ -354,7 +350,7 @@ def use_trtllm_attention( ...@@ -354,7 +350,7 @@ def use_trtllm_attention(
return True return True
if force_use_trtllm is None: if force_use_trtllm is None:
# Environment variable not set - use auto-detection # CLI argument not set - use auto-detection
if is_prefill: if is_prefill:
# Prefill auto-detection # Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto" use_trtllm = kv_cache_dtype == "auto"
...@@ -367,8 +363,10 @@ def use_trtllm_attention( ...@@ -367,8 +363,10 @@ def use_trtllm_attention(
logger.warning_once("Using TRTLLM decode attention (auto-detected).") logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm return use_trtllm
# Environment variable is set to 1 - respect it # CLI argument is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True return True
...@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm( ...@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return output return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__ = [ __all__ = [
"has_flashinfer", "has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_trtllm_fp8_block_scale_moe",
...@@ -526,7 +518,6 @@ __all__ = [ ...@@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention", "supports_trtllm_attention",
"can_use_trtllm_attention", "can_use_trtllm_attention",
"use_trtllm_attention", "use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm", "flashinfer_scaled_fp8_mm",
] ]
...@@ -11,6 +11,17 @@ from typing import Any ...@@ -11,6 +11,17 @@ from typing import Any
import cbor2 import cbor2
try:
# It is important that this remains an optional dependency.
# It would not be allowed in environments with strict security controls,
# so it's best not to have it installed when not in use.
import xxhash as _xxhash
if not hasattr(_xxhash, "xxh3_128_digest"):
_xxhash = None
except ImportError: # pragma: no cover
_xxhash = None
def sha256(input: Any) -> bytes: def sha256(input: Any) -> bytes:
"""Hash any picklable Python object using SHA-256. """Hash any picklable Python object using SHA-256.
...@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes: ...@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
return hashlib.sha256(input_bytes).digest() return hashlib.sha256(input_bytes).digest()
def _xxhash_digest(input_bytes: bytes) -> bytes:
if _xxhash is None:
raise ModuleNotFoundError(
"xxhash is required for the 'xxhash' prefix caching hash algorithms. "
"Install it via `pip install xxhash`."
)
return _xxhash.xxh3_128_digest(input_bytes)
def xxhash(input: Any) -> bytes:
"""Hash picklable objects using xxHash."""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return _xxhash_digest(input_bytes)
def xxhash_cbor(input: Any) -> bytes:
"""Hash objects serialized with CBOR using xxHash."""
input_bytes = cbor2.dumps(input, canonical=True)
return _xxhash_digest(input_bytes)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if the function is not found. """Get a hash function by name, or raise an error if the function is not found.
...@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: ...@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256 return sha256
if hash_fn_name == "sha256_cbor": if hash_fn_name == "sha256_cbor":
return sha256_cbor return sha256_cbor
if hash_fn_name == "xxhash":
return xxhash
if hash_fn_name == "xxhash_cbor":
return xxhash_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}") raise ValueError(f"Unsupported hash function: {hash_fn_name}")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import torch
import torch.cuda.nvtx as nvtx
def print_tensor(tensor_obj, prefix, tensor_list=None):
"""Descends iterators that contains Tensors and prints the Tensor.
Recursive function that descends iterator type arguments until
it finds a Tensor object.
"""
if tensor_list is None:
tensor_list = []
if isinstance(tensor_obj, (list, tuple)):
for ten in tensor_obj:
tensor_list = print_tensor(ten, prefix, tensor_list)
elif isinstance(tensor_obj, torch.Tensor):
tensor_dims = list(tensor_obj.size())
tensor_list.append(tensor_dims)
return tensor_list
def process_layer_params(module_obj):
"""Extract the static parameters from LLM and VLM relevant layer types"""
param_info = {}
# Extract parameters for layers commonly used in LLMs and VLMs
if isinstance(module_obj, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
conv_params = {}
conv_params["in_chan"] = module_obj.in_channels
conv_params["out_chan"] = module_obj.out_channels
conv_params["filter_dim"] = module_obj.kernel_size
conv_params["stride"] = module_obj.stride
conv_params["padding"] = module_obj.padding
conv_params["dilation"] = module_obj.dilation
conv_params["transposed"] = module_obj.transposed
conv_params["output_padding"] = module_obj.output_padding
conv_params["groups"] = module_obj.groups
conv_params["padding_mode"] = module_obj.padding_mode
param_info = conv_params
elif isinstance(
module_obj,
(
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
),
):
convtranspose_params = {}
convtranspose_params["in_chan"] = module_obj.in_channels
convtranspose_params["out_chan"] = module_obj.out_channels
convtranspose_params["filter_dim"] = module_obj.kernel_size
convtranspose_params["stride"] = module_obj.stride
convtranspose_params["padding"] = module_obj.padding
convtranspose_params["dilation"] = module_obj.dilation
convtranspose_params["transposed"] = module_obj.transposed
convtranspose_params["output_padding"] = module_obj.output_padding
convtranspose_params["groups"] = module_obj.groups
convtranspose_params["padding_mode"] = module_obj.padding_mode
param_info = convtranspose_params
elif isinstance(
module_obj, (torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)
):
def _handle_int_or_tuple(parameter):
if isinstance(parameter, tuple):
return list(parameter)
elif isinstance(parameter, int):
return [parameter, parameter]
pooling_params = {}
pooling_params["filter_dim"] = _handle_int_or_tuple(module_obj.kernel_size)
pooling_params["stride"] = _handle_int_or_tuple(module_obj.stride)
pooling_params["padding"] = _handle_int_or_tuple(module_obj.padding)
pooling_params["dilation"] = _handle_int_or_tuple(module_obj.dilation)
param_info = pooling_params
elif isinstance(
module_obj, (torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d)
):
pooling_params = {}
pooling_params["filter_dim"] = [
module_obj.kernel_size,
module_obj.kernel_size,
]
pooling_params["stride"] = [module_obj.stride, module_obj.stride]
pooling_params["padding"] = [module_obj.padding, module_obj.padding]
pooling_params["ceil_mode"] = module_obj.ceil_mode
pooling_params["count_include_pad"] = module_obj.count_include_pad
param_info = pooling_params
elif isinstance(
module_obj,
(
torch.nn.AdaptiveAvgPool1d,
torch.nn.AdaptiveAvgPool2d,
torch.nn.AdaptiveAvgPool3d,
),
):
pooling_params = {}
pooling_params["output_size"] = [
module_obj.output_size,
module_obj.output_size,
]
param_info = pooling_params
elif isinstance(module_obj, torch.nn.Linear):
param_info["in_features"] = module_obj.in_features
param_info["out_features"] = module_obj.out_features
elif isinstance(
module_obj,
(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d),
):
param_info["num_features"] = module_obj.num_features
param_info["epsilon"] = module_obj.eps
param_info["momentum"] = module_obj.momentum
elif isinstance(module_obj, torch.nn.ReLU):
param_info["in_place"] = module_obj.inplace
elif isinstance(module_obj, torch.nn.Dropout):
param_info["p"] = module_obj.p
param_info["in_place"] = module_obj.inplace
elif isinstance(module_obj, torch.nn.Embedding):
param_info["num_embeddings"] = module_obj.num_embeddings
param_info["embedding_dim"] = module_obj.embedding_dim
elif isinstance(
module_obj,
(
torch.nn.Upsample,
torch.nn.UpsamplingNearest2d,
torch.nn.UpsamplingBilinear2d,
),
):
param_info["scale_factor"] = module_obj.scale_factor
return param_info
def construct_marker_dict_and_push(
module_name, module_obj, in_tensor, kwargs=None, out_tensor=None
):
marker_dict = {}
marker_dict["Module"] = module_name
## Get trainable parameters like weights and bias
module_params = module_obj.named_parameters(recurse=False)
for idx, (param_name, param_obj) in enumerate(module_params):
if idx == 0:
marker_dict["TrainableParams"] = {}
marker_dict["TrainableParams"][param_name] = list(param_obj.size())
in_tensor_list = print_tensor(in_tensor, "Input")
if in_tensor_list:
marker_dict["Inputs"] = in_tensor_list
out_tensor_list = print_tensor(out_tensor, "Output")
if out_tensor_list:
marker_dict["Outputs"] = out_tensor_list
## Get Kwargs like input_ids and positions for the top module
if kwargs:
for key, value in kwargs.items():
if isinstance(value, (torch.Tensor, list, tuple)):
tensor_list = print_tensor(value, key)
if tensor_list:
marker_dict[key] = tensor_list
param_info = process_layer_params(module_obj)
if param_info:
marker_dict["StaticParams"] = param_info
nvtx.range_push("{}".format(marker_dict))
class ResultHolder:
"""Holder for storing results from within a context manager."""
result = None
@contextmanager
def layerwise_nvtx_marker_context(module_name, module_obj, in_tensor=None, kwargs=None):
"""Context manager for NVTX markers that automatically pushes on enter
and pops on exit.
Example:
with nvtx_marker_context("Module:MyModule", module, in_tensor=args,
kwargs=kwargs) as ctx:
ctx.result = module(*args, **kwargs)
return ctx.result
"""
holder = ResultHolder()
# Push input marker
construct_marker_dict_and_push(
module_name,
module_obj,
in_tensor=in_tensor,
kwargs=kwargs,
)
try:
yield holder
finally:
# Pop input marker
nvtx.range_pop()
# Push and pop output marker
output_name = module_name.replace("(input)", "(output)")
construct_marker_dict_and_push(
output_name,
module_obj,
in_tensor=None,
kwargs=None,
out_tensor=holder.result,
)
nvtx.range_pop()
class PytHooks:
"""This module contains all the code needed to enable forward hooks
in a pytorch network.
To register the hooks for a given network, the user needs to instantiate
a PytHook object. Then call the register_hooks method.
Example:
my_hook = PytHook()
my_hook.register_hooks(my_network_model)
"""
def __init__(self):
"""Initialize module variables."""
super().__init__()
self.module_to_name_map = {}
def _process_layer_params(self, module_obj):
return process_layer_params(module_obj)
def module_fwd_hook(self, module_obj, in_tensor, out_tensor):
"""Callback function that ends the NVTX marker.
Records the module name and tensor information.
Called after the module executes the forward method.
"""
nvtx.range_pop()
module_name = self.module_to_name_map.get(module_obj, "unknown")
construct_marker_dict_and_push(
module_name, module_obj, in_tensor=None, kwargs=None, out_tensor=out_tensor
)
nvtx.range_pop()
return
def module_fwd_pre_hook(self, module_obj, in_tensor, kwargs):
"""Creates an NVTX marker with the module name in it.
This function is called before the module executes.
"""
module_name = self.module_to_name_map.get(module_obj, "unknown")
construct_marker_dict_and_push(
module_name, module_obj, in_tensor=in_tensor, kwargs=kwargs, out_tensor=None
)
return
def register_hooks(self, network_model, module_prefix="top"):
"""User level function that activates all the hooks.
The user needs to call this method from the network source code.
The code descends all the modules in the network and registers their
respective hooks.
"""
# Module types to skip (simple operations that don't need detailed profiling)
skip_types = (
torch.nn.Identity,
torch.nn.Dropout,
torch.nn.Dropout1d,
torch.nn.Dropout2d,
torch.nn.Dropout3d,
)
for name, module in network_model.named_modules(prefix=module_prefix):
# Skip certain module types to reduce profiling overhead
if isinstance(module, skip_types):
continue
module.register_forward_pre_hook(self.module_fwd_pre_hook, with_kwargs=True)
module.register_forward_hook(self.module_fwd_hook)
if module not in self.module_to_name_map:
self.module_to_name_map[module] = name
else:
raise ValueError("Module instance {} is not unique ".format(module))
return
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64 import base64
import io
import math
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import TYPE_CHECKING, Any, Literal
import numpy as np import numpy as np
import torch import torch
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm import PoolingRequestOutput if TYPE_CHECKING:
from vllm import PoolingRequestOutput
else:
PoolingRequestOutput = Any
sys_byteorder = sys.byteorder sys_byteorder = sys.byteorder
...@@ -26,6 +31,14 @@ EMBED_DTYPE_TO_TORCH_DTYPE = { ...@@ -26,6 +31,14 @@ EMBED_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.float8_e5m2, "fp8_e5m2": torch.float8_e5m2,
} }
EMBED_DTYPE_TO_N_BYTES = {
"float32": 4,
"float16": 2,
"bfloat16": 2,
"fp8_e4m3": 1,
"fp8_e5m2": 1,
}
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = { EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
"float32": torch.float32, "float32": torch.float32,
...@@ -49,7 +62,16 @@ ENDIANNESS = ["native", "big", "little"] ...@@ -49,7 +62,16 @@ ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"] Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"] EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
def tensor2base64(x: torch.Tensor) -> str:
with io.BytesIO() as buf:
torch.save(x, buf)
buf.seek(0)
binary_data = buf.read()
return base64.b64encode(binary_data).decode("utf-8")
def tensor2binary( def tensor2binary(
...@@ -104,7 +126,7 @@ def encode_pooling_output( ...@@ -104,7 +126,7 @@ def encode_pooling_output(
elif encoding_format == "base64": elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness) embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8") return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes": elif encoding_format == "bytes" or encoding_format == "bytes_only":
return tensor2binary(output.outputs.data, embed_dtype, endianness) return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format) assert_never(encoding_format)
...@@ -119,6 +141,29 @@ class MetadataItem: ...@@ -119,6 +141,29 @@ class MetadataItem:
shape: tuple[int, ...] shape: tuple[int, ...]
def build_metadata_items(
embed_dtype: EmbedDType,
endianness: Endianness,
shape: tuple[int, ...],
n_request: int,
):
n_bytes = EMBED_DTYPE_TO_N_BYTES[embed_dtype]
size = math.prod(shape)
items = [
MetadataItem(
index=i,
embed_dtype=embed_dtype,
endianness=endianness,
start=i * size * n_bytes,
end=(i + 1) * size * n_bytes,
shape=shape,
)
for i in range(n_request)
]
return items
def encode_pooling_bytes( def encode_pooling_bytes(
pooling_outputs: list[PoolingRequestOutput], pooling_outputs: list[PoolingRequestOutput],
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
......
...@@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: ...@@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def decorate_logs(process_name: str | None = None) -> None: def decorate_logs(process_name: str | None = None) -> None:
"""Decorate stdout/stderr with process name and PID prefix.""" """Decorate stdout/stderr with process name and PID prefix."""
# Respect VLLM_CONFIGURE_LOGGING environment variable
if not envs.VLLM_CONFIGURE_LOGGING:
return
if process_name is None: if process_name is None:
process_name = get_mp_context().current_process().name process_name = get_mp_context().current_process().name
......
...@@ -28,6 +28,7 @@ else: ...@@ -28,6 +28,7 @@ else:
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32, "float32": torch.float32,
"half": torch.half, "half": torch.half,
"float16": torch.float16,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
......
...@@ -8,7 +8,6 @@ from typing import ClassVar ...@@ -8,7 +8,6 @@ from typing import ClassVar
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
...@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config self.parallel_config
...@@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# When using cuda graph, we need to set the upper bound of the # When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are # number of splits so that large enough intermediate buffers are
# pre-allocated during capture. # pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH self.max_num_splits = (
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
# Sliding window size to be used with the AOT scheduler will be # Sliding window size to be used with the AOT scheduler will be
# populated on first build() call. # populated on first build() call.
...@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer" "heads in the layer"
) )
def supports_quant_query_input(self) -> bool: self.supports_quant_query_input = True
return True
def forward( def forward(
self, self,
......
...@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import ( ...@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
) )
from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability ...@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
can_use_trtllm_attention, can_use_trtllm_attention,
flashinfer_disable_q_quantization,
use_trtllm_attention, use_trtllm_attention,
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend): ...@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
supports_trtllm_attention, supports_trtllm_attention,
) )
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0) # Respect explicit disable flag (e.g.,
# --attention-config.use_trtllm_attention=0)
if force_use_trtllm_attention() is False: if force_use_trtllm_attention() is False:
return False return False
...@@ -482,9 +482,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -482,9 +482,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1 self.dcp_kv_cache_interleave_size = 1
self.num_qo_heads = ( self.num_qo_heads = self.model_config.get_num_attention_heads(
self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) self.vllm_config.parallel_config
* self.dcp_world_size
) )
self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.num_kv_heads = self.kv_cache_spec.num_kv_heads
...@@ -501,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -501,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_dtype = self.kv_cache_spec.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype
# Use model dtype as q dtype when TRTLLM attn is not supported, or # Use model dtype as q dtype when TRTLLM attn is not supported, or
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
# use fp8 q if kv cache is fp8, and will fall back to model dtype # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata # if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if can_use_trtllm and not flashinfer_disable_q_quantization(): if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
):
self.q_data_type = self.kv_cache_dtype self.q_data_type = self.kv_cache_dtype
else: else:
self.q_data_type = self.model_config.dtype self.q_data_type = self.model_config.dtype
...@@ -1036,6 +1038,11 @@ class FlashInferImpl(AttentionImpl): ...@@ -1036,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks self.sinks = sinks
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config()
self.supports_quant_query_input = (
self.support_trtllm_attn
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None self.o_sf_scale: float | None = None
...@@ -1047,12 +1054,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -1047,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
) )
def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False
return self.support_trtllm_attn
# FlashInfer requires attention sinks to be float32 # FlashInfer requires attention sinks to be float32
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
if self.sinks is not None and self.sinks.dtype != torch.float32: if self.sinks is not None and self.sinks.dtype != torch.float32:
......
...@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import ( ...@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
and_masks, and_masks,
create_block_mask, create_block_mask,
flex_attention, flex_attention,
or_masks,
) )
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
...@@ -31,6 +32,7 @@ from vllm.logger import init_logger ...@@ -31,6 +32,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
...@@ -41,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec ...@@ -41,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
torch._dynamo.config.recompile_limit = 16
create_block_mask_compiled = torch.compile( create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead" create_block_mask, fullgraph=True, mode="reduce-overhead"
) )
...@@ -90,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -90,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
"""FlexAttention supports both decoder and encoder-only attention.""" """FlexAttention supports both decoder and encoder-only attention."""
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY) return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@classmethod
def supports_mm_prefix(cls) -> bool:
"""FlexAttention supports full attention for image tokens."""
return True
@staticmethod @staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]: def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl return FlexAttentionImpl
...@@ -315,6 +323,7 @@ class FlexAttentionMetadata: ...@@ -315,6 +323,7 @@ class FlexAttentionMetadata:
kv_block_size: int = 16 kv_block_size: int = 16
transformed_score_mod: _score_mod_signature | None = None transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@cached_property @cached_property
def logical_block_ids(self): def logical_block_ids(self):
...@@ -442,6 +451,45 @@ class FlexAttentionMetadata: ...@@ -442,6 +451,45 @@ class FlexAttentionMetadata:
return final_mask_mod if self.causal else sliding_window_mask_mod return final_mask_mod if self.causal else sliding_window_mask_mod
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
"""Creates the prefix LM mask_mod function for FlexAttention."""
assert self.doc_ids is not None
request_lookup = self.doc_ids
def prefix_lm_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
cu_q_idx: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
):
mask = torch.zeros_like(q_idx, dtype=torch.bool)
for req, doc_range_lst in (self.mm_prefix_range or {}).items():
req_mask = request_lookup[cu_q_idx] == req
for start, end in doc_range_lst:
doc_mask_q = (q_idx >= start) & (q_idx <= end)
doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
return mask
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx, logical_kv_idx) = (
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
)
return torch.where(
is_valid,
prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod
def get_mask_mod(self): def get_mask_mod(self):
# Stage-1: initialize the base mask_mod # Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder) # (causal mask for decoder or bidirectional mask for encoder)
...@@ -455,6 +503,10 @@ class FlexAttentionMetadata: ...@@ -455,6 +503,10 @@ class FlexAttentionMetadata:
# Add sliding window mask for sliding window attention # Add sliding window mask for sliding window attention
sliding_window_mask_mod = self.get_sliding_window_mask_mod() sliding_window_mask_mod = self.get_sliding_window_mask_mod()
mask_mod = and_masks(mask_mod, sliding_window_mask_mod) mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
if self.mm_prefix_range:
# Add prefix LM mask for vision-language prefix LM attention
prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
return mask_mod return mask_mod
def get_transformed_score_mod(self) -> _score_mod_signature | None: def get_transformed_score_mod(self) -> _score_mod_signature | None:
...@@ -708,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -708,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
sliding_window: int | None sliding_window: int | None
alibi_slopes: torch.Tensor | None alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
def __init__( def __init__(
self, self,
...@@ -809,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -809,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
needs_rebuild_block_mask = False
if attn_metadata.sliding_window != self.sliding_window: if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build: if attn_metadata.direct_build:
# update mask mod in attention metadata # update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod() attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
self.mm_prefix_range = attn_metadata.mm_prefix_range
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if needs_rebuild_block_mask:
if attn_metadata.direct_build and attn_metadata.causal:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct() attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else: else:
attn_metadata.block_mask = attn_metadata.build_block_mask() attn_metadata.block_mask = attn_metadata.build_block_mask()
...@@ -927,7 +990,18 @@ def get_kernel_options( ...@@ -927,7 +990,18 @@ def get_kernel_options(
if torch.cuda.is_available(): if torch.cuda.is_available():
device_props = torch.cuda.get_device_properties() device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin # ROCm doesn't expose shared_memory_per_block_optin attribute
# AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup
if hasattr(device_props, "shared_memory_per_block_optin"):
max_shared_memory = device_props.shared_memory_per_block_optin
elif current_platform.is_rocm():
# ROCm fallback: use 64KB
max_shared_memory = 65536
else:
raise RuntimeError(
"Unable to determine shared memory size on this hardware."
)
if max_shared_memory < 144 * 1024: if max_shared_memory < 144 * 1024:
block_m_candidate = ensure_divisible( block_m_candidate = ensure_divisible(
max(1, block_m_candidate // 2), block_m max(1, block_m_candidate // 2), block_m
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment