Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
......@@ -749,6 +749,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D=None,
num_par_softmax_segments=None,
softmax_segm_output=None,
softmax_segm_max=None,
softmax_segm_expsum=None,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
......@@ -793,8 +798,19 @@ def unified_attention(
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
if (
seq_threshold_3D is None
or num_par_softmax_segments is None
or softmax_segm_output is None
or softmax_segm_max is None
or softmax_segm_expsum is None
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
):
kernel_unified_attention_2d[
(
total_num_q_blocks,
......@@ -847,37 +863,12 @@ def unified_attention(
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16
segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
......@@ -917,13 +908,13 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
......@@ -936,6 +927,6 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)
......@@ -16,6 +16,7 @@ import einops
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
......@@ -44,9 +45,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p=0.0,
causal=False,
)
context_layer = einops.rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
......@@ -59,8 +58,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int,
is_rocm_aiter: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)
direct_register_custom_op(
......@@ -92,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
......@@ -106,7 +111,6 @@ def torch_sdpa_wrapper(
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer
......@@ -116,8 +120,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
return torch.empty_like(q)
direct_register_custom_op(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from functools import cache
from typing import cast, get_args
from typing import NamedTuple, cast, get_args
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
......@@ -19,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class AttentionSelectorConfig(NamedTuple):
    """Immutable bundle of the parameters used to pick an attention backend.

    The first four fields are required; the remaining ones have defaults.
    """

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        """Render as ``AttentionSelectorConfig(field=value, ...)`` in field order."""
        # ``_fields`` lists the tuple's field names in declaration order, which
        # is the order the representation uses.
        rendered_fields = ", ".join(
            f"{field_name}={getattr(self, field_name)}" for field_name in self._fields
        )
        return f"AttentionSelectorConfig({rendered_fields})"
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
......@@ -44,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
......@@ -54,58 +77,26 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)
@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls)
if "use_v1" in sig.parameters:
logger.warning_once(
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
True, # use_v1
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
......
......@@ -235,7 +235,9 @@ async def get_request(
def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float]
outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float],
) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests.
......@@ -788,7 +790,7 @@ async def benchmark(
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
"Total token throughput (tok/s):", metrics.total_token_throughput
)
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the cold and warm startup time of vLLM models.
This script measures total startup time (including model loading, compilation,
and cache operations) for both cold and warm scenarios:
- Cold startup: Fresh start with no caches (temporary cache directories)
- Warm startup: Using cached compilation and model info
"""
import argparse
import dataclasses
import json
import multiprocessing
import os
import shutil
import tempfile
import time
from contextlib import contextmanager
from typing import Any
import numpy as np
from tqdm import tqdm
from vllm.benchmarks.lib.utils import (
convert_to_pytorch_benchmark_format,
write_to_json,
)
from vllm.engine.arg_utils import EngineArgs
@contextmanager
def cold_startup():
"""
Context manager to measure cold startup time:
1. Uses a temporary directory for vLLM cache to avoid any pollution
between cold startup iterations.
2. Uses inductor's fresh_cache to clear torch.compile caches.
"""
from torch._inductor.utils import fresh_cache
# Use temporary directory for caching to avoid any pollution between cold startups
original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
try:
os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
with fresh_cache():
yield
finally:
# Clean up temporary cache directory
shutil.rmtree(temp_cache_dir, ignore_errors=True)
if original_cache_root:
os.environ["VLLM_CACHE_ROOT"] = original_cache_root
else:
os.environ.pop("VLLM_CACHE_ROOT", None)
def run_startup_in_subprocess(engine_args_dict, result_queue):
    """
    Start an LLM and report how long it took through ``result_queue``.

    On success one dict with the keys ``total_startup_time`` and
    ``compilation_time`` is put on the queue. On failure ``None`` is put
    first, followed by the text of the exception.
    """
    try:
        # Deferred imports so they are resolved where this function runs.
        from vllm import LLM
        from vllm.engine.arg_utils import EngineArgs

        engine_args = EngineArgs(**engine_args_dict)

        # Time the construction of the LLM object.
        began_at = time.perf_counter()
        llm = LLM(**dataclasses.asdict(engine_args))
        elapsed = time.perf_counter() - began_at

        # The compilation time is optional: fall back to 0.0 when the engine
        # exposes no vllm_config or no compilation_config.
        vllm_config = getattr(llm.llm_engine, "vllm_config", None)
        compilation_config = getattr(vllm_config, "compilation_config", None)
        if compilation_config is None:
            compilation_time = 0.0
        else:
            compilation_time = compilation_config.compilation_time

        timings = {
            "total_startup_time": elapsed,
            "compilation_time": compilation_time,
        }
        result_queue.put(timings)
    except Exception as exc:
        # Failure protocol: a None marker, then the error text.
        result_queue.put(None)
        result_queue.put(str(exc))
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write one PyTorch-benchmark-format JSON file per measured quantity.

    Four record sets are produced in this order: cold startup, cold
    compilation, warm startup, warm compilation. Each one is written to
    ``<base>.<phase>_<metric>.pytorch.json``, where ``<base>`` is
    ``args.output_json`` without its extension. A record set is skipped when
    the conversion result is falsy.
    """
    base_name = os.path.splitext(args.output_json)[0]

    for phase in ("cold", "warm"):
        for metric in ("startup", "compilation"):
            tag = f"{phase}_{metric}"
            average_key = f"avg_{tag}_time"
            times_key = f"{tag}_times"
            percentiles_key = f"{tag}_percentiles"

            records = convert_to_pytorch_benchmark_format(
                args=args,
                metrics={average_key: results[average_key]},
                extra_info={
                    times_key: results[times_key],
                    percentiles_key: results[percentiles_key],
                },
            )
            if records:
                write_to_json(f"{base_name}.{tag}.pytorch.json", records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the benchmark options, then the engine options, on ``parser``.

    Returns whatever ``EngineArgs.add_cli_args`` returns for the parser.
    """
    # (flag, default, help) for the integer iteration-count options.
    iteration_options = (
        ("--num-iters-cold", 5, "Number of cold startup iterations."),
        (
            "--num-iters-warmup",
            3,
            "Number of warmup iterations before benchmarking warm startups.",
        ),
        ("--num-iters-warm", 5, "Number of warm startup iterations."),
    )
    for flag, default_value, help_text in iteration_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the startup time results in JSON format.",
    )

    return EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Run the cold and warm startup benchmarks and report the results.

    Every measurement starts the engine in a freshly spawned subprocess. Cold
    iterations run inside ``cold_startup()``; warm iterations run after
    ``args.num_iters_warmup`` unmeasured warmup startups. Averages and
    percentiles are printed, and when ``args.output_json`` is set they are
    also written as JSON and in PyTorch benchmark format.

    Raises:
        RuntimeError: If a subprocess reports a failure or exits without
            sending a result.
    """
    # Local import: only needed for the ``queue.Empty`` exception below.
    import queue

    # Set multiprocessing start method to 'spawn' for clean process isolation
    # This ensures each subprocess starts fresh without inheriting state
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """
        Create LLM instance in a subprocess and measure startup time.
        Returns timing metrics, using subprocess for complete isolation.
        """
        # Convert engine_args to dictionary for pickling
        engine_args_dict = dataclasses.asdict(engine_args)

        # Create a queue for inter-process communication
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=run_startup_in_subprocess,
            args=(
                engine_args_dict,
                result_queue,
            ),
        )
        process.start()

        # Sentinel for "the child exited without sending anything (more)".
        no_message = object()

        def next_message():
            """Return the child's next message, or ``no_message`` if it is gone."""
            while True:
                try:
                    return result_queue.get(timeout=0.5)
                except queue.Empty:
                    if not process.is_alive():
                        break
            # The child has exited, so anything it sent has already been
            # flushed to the pipe; one non-blocking read picks it up.
            try:
                return result_queue.get_nowait()
            except queue.Empty:
                return no_message

        # Read the messages *before* joining: a child that has put items on a
        # multiprocessing.Queue does not terminate until they are flushed, so
        # joining first can deadlock.
        result = next_message()
        error_msg = next_message() if result is None else no_message
        process.join()

        if result is no_message:
            raise RuntimeError("Subprocess did not return a result")
        if result is None:
            # Failure protocol of the child: None followed by the error text.
            if error_msg is no_message:
                raise RuntimeError("Subprocess failed with unknown error")
            raise RuntimeError(f"Subprocess failed: {error_msg}")
        return result

    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    print("Measuring cold startup time...\n")
    cold_startup_times = []
    cold_compilation_times = []
    for _ in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
        cold_startup_times.append(metrics["total_startup_time"])
        cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    print("\nMeasuring warm startup time...\n")
    warm_startup_times = []
    warm_compilation_times = []
    for _ in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    cold_startup_array = np.array(cold_startup_times)
    cold_compilation_array = np.array(cold_compilation_times)
    warm_startup_array = np.array(warm_startup_times)
    warm_compilation_array = np.array(warm_compilation_times)

    avg_cold_startup = np.mean(cold_startup_array)
    avg_cold_compilation = np.mean(cold_compilation_array)
    avg_warm_startup = np.mean(warm_startup_array)
    avg_warm_compilation = np.mean(warm_compilation_array)

    percentages = [10, 25, 50, 75, 90, 99]
    cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
    cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
    warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
    warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)

    print("\n" + "=" * 60)
    print("STARTUP TIME BENCHMARK RESULTS")
    print("=" * 60)

    # Cold startup statistics
    print("\nCOLD STARTUP:")
    print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
    print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
    print("Startup time percentiles:")
    for percentage, percentile in zip(percentages, cold_startup_percentiles):
        print(f" {percentage}%: {percentile:.2f} seconds")
    print("Compilation time percentiles:")
    for percentage, percentile in zip(percentages, cold_compilation_percentiles):
        print(f" {percentage}%: {percentile:.2f} seconds")

    # Warm startup statistics
    print("\nWARM STARTUP:")
    print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
    print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
    print("Startup time percentiles:")
    for percentage, percentile in zip(percentages, warm_startup_percentiles):
        print(f" {percentage}%: {percentile:.2f} seconds")
    print("Compilation time percentiles:")
    for percentage, percentile in zip(percentages, warm_compilation_percentiles):
        print(f" {percentage}%: {percentile:.2f} seconds")

    print("=" * 60)

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_cold_startup_time": float(avg_cold_startup),
            "avg_cold_compilation_time": float(avg_cold_compilation),
            "cold_startup_times": cold_startup_times,
            "cold_compilation_times": cold_compilation_times,
            "cold_startup_percentiles": dict(
                zip(percentages, cold_startup_percentiles.tolist())
            ),
            "cold_compilation_percentiles": dict(
                zip(percentages, cold_compilation_percentiles.tolist())
            ),
            "avg_warm_startup_time": float(avg_warm_startup),
            "avg_warm_compilation_time": float(avg_warm_compilation),
            "warm_startup_times": warm_startup_times,
            "warm_compilation_times": warm_compilation_times,
            "warm_startup_percentiles": dict(
                zip(percentages, warm_startup_percentiles.tolist())
            ),
            "warm_compilation_percentiles": dict(
                zip(percentages, warm_compilation_percentiles.tolist())
            ),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
......@@ -141,7 +141,25 @@ class CompilerManager:
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
self.cache = ast.literal_eval(f.read())
cache = ast.literal_eval(f.read())
def check_type(value, ty):
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]:
range_tuple, graph_index, compiler_name = key
check_type(graph_index, int)
check_type(compiler_name, str)
if isinstance(range_tuple, tuple):
start, end = range_tuple
check_type(start, int)
check_type(end, int)
range_tuple = Range(start=start, end=end)
check_type(range_tuple, Range)
return range_tuple, graph_index, compiler_name
self.cache = {parse_key(key): value for key, value in cache.items()}
self.compiler.initialize_cache(
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
......@@ -445,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"
model_is_encoder: bool = False
@contextmanager
def set_model_tag(tag: str):
def set_model_tag(tag: str, is_encoder: bool = False):
"""Context manager to set the model tag."""
global model_tag
global model_is_encoder
assert tag != model_tag, (
f"Model tag {tag} is the same as the current tag {model_tag}."
)
old_tag = model_tag
old_is_encoder = model_is_encoder
model_tag = tag
model_is_encoder = is_encoder
try:
yield
finally:
model_tag = old_tag
model_is_encoder = old_is_encoder
class VllmBackend:
......@@ -505,6 +529,9 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
# Mark compilation for encoder.
self.is_encoder = model_is_encoder
# Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname(
current_platform.get_pass_manager_cls()
......
......@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo
from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from .monitor import start_monitoring_torch_compile
......@@ -316,7 +316,13 @@ def _support_torch_compile(
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED:
torch._dynamo.decorators.mark_unbacked(arg, dims)
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else:
torch._dynamo.mark_dynamic(arg, dims)
......@@ -350,7 +356,13 @@ def _support_torch_compile(
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.decorators.mark_unbacked(arg, dims)
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
......@@ -378,14 +390,6 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
# Validate that AOT compile is not used with unbacked dynamic
# shapes. aot_compile re-allocates backed symbols post dynamo!
if ds_type == DynamicShapesType.UNBACKED:
raise ValueError(
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
"when VLLM_USE_AOT_COMPILE is enabled."
)
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
......@@ -488,6 +492,12 @@ def _support_torch_compile(
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0.dev+
inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0.dev"):
inductor_config_patches["assume_32bit_indexing"] = True
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
......@@ -496,6 +506,7 @@ def _support_torch_compile(
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches),
):
if envs.VLLM_USE_AOT_COMPILE:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
......
......@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
......@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
def __init__(
self,
epsilon: float,
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
......@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
......@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
......@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
......@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
)
scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
at = auto_functionalized(
self.FUSED_OP,
result=result,
......@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None,
residual=residual,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
is_scale_transposed=self.has_col_major_scales,
)
# result, residual, scale
......@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape,
symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(
......@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
)
self.group_shape = group_shape
super().__init__(epsilon, key)
super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor):
......@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales
input, transposed=self.quant_matcher.has_col_major_scales
)
at = auto_functionalized(
self.FUSED_OP,
......@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None,
residual=None,
group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales,
is_scale_transposed=self.quant_matcher.has_col_major_scales,
)
# result, scale
......@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
......@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
......
......@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
self,
quant_key: QuantKey,
enabled: bool | None = None,
use_col_major_scales: bool = False,
use_e8m0: bool = False,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
)
def forward_custom(
self,
......@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
if self.quant_key.scale.group_shape.is_per_group():
assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales)
scale = self.make_scale(input, transposed=self.has_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min
......@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
eps=1e-10,
fp8_min=fp8_min,
fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0,
scale_ue8m0=self.is_e8m0,
)
return result, scale
......
......@@ -53,12 +53,7 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.is_encoder_compilation = vllm_backend.is_encoder
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
......
......@@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr = self.check_invariants_and_forward
aot_context = nullcontext()
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
else:
msg = "torch._dynamo.config.enable_aot_compile is not "
msg += "available. AOT compile is disabled and please "
msg += "upgrade PyTorch version to use AOT compile."
logger.warning(msg)
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)
with aot_context:
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
......
......@@ -8,7 +8,7 @@ from dataclasses import field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import Field, TypeAdapter, field_validator
from pydantic import ConfigDict, Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
......@@ -17,7 +17,6 @@ from vllm.config.utils import (
Range,
config,
get_hash_factors,
handle_deprecated,
hash_factors,
)
from vllm.logger import init_logger
......@@ -97,7 +96,7 @@ class CUDAGraphMode(enum.Enum):
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class PassConfig:
"""Configuration for custom Inductor passes.
......@@ -127,27 +126,6 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
......@@ -206,15 +184,7 @@ class PassConfig:
Any future fields that don't affect compilation should be excluded.
"""
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
return hash_factors(get_hash_factors(self, set()))
@field_validator(
"fuse_norm_quant",
......@@ -224,12 +194,6 @@ class PassConfig:
"enable_sp",
"fuse_gemm_comms",
"fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
mode="wrap",
)
@classmethod
......@@ -242,49 +206,6 @@ class PassConfig:
def __post_init__(self) -> None:
# Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(
......@@ -330,7 +251,7 @@ class DynamicShapesType(str, enum.Enum):
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""
......@@ -369,7 +290,7 @@ class DynamicShapesConfig:
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig:
"""Configuration for compilation.
......@@ -1011,9 +932,13 @@ class CompilationConfig:
self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
)
if (
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
):
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
......
......@@ -64,6 +64,11 @@ class KVTransferConfig:
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......
......@@ -8,7 +8,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
......@@ -73,17 +73,6 @@ logger = init_logger(__name__)
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal[
"auto",
"generate",
"embedding",
"embed",
"classify",
"score",
"reward",
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
......@@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"],
"pooling": ["embedding", "embed", "classify", "score", "reward"],
"draft": ["draft"],
}
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [],
"pooling": ["embed", "classify", "reward"],
......@@ -126,13 +109,7 @@ class ModelConfig:
"""Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks."""
task: TaskOption | None = None
"""[DEPRECATED] The task to use the model for. If the model supports more
than one model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner.
"""
tokenizer: SkipValidation[str] = None # type: ignore
tokenizer: str = Field(default=None)
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto"
......@@ -187,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: SkipValidation[int] = None # type: ignore
max_model_len: int = Field(default=None, gt=0)
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
......@@ -198,7 +175,7 @@ class ModelConfig:
- 25.6k -> 25,600"""
spec_target_max_model_len: int | None = None
"""Specify the maximum length for spec decoding draft models."""
quantization: SkipValidation[QuantizationMethods | None] = None
quantization: QuantizationMethods | str | None = None
"""Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
......@@ -338,7 +315,6 @@ class ModelConfig:
ignored_factors = {
"runner",
"convert",
"task",
"tokenizer",
"tokenizer_mode",
"seed",
......@@ -513,97 +489,6 @@ class ModelConfig:
is_generative_model = registry.is_text_generation_model(architectures, self)
is_pooling_model = registry.is_pooling_model(architectures, self)
def _task_to_convert(task: TaskOption) -> ConvertType:
if task == "embedding" or task == "embed":
return "embed"
if task == "classify":
return "classify"
if task == "reward":
logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score":
new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed"
return "none"
if self.task is not None:
runner: RunnerOption = "auto"
convert: ConvertOption = "auto"
msg_prefix = (
"The 'task' option has been deprecated and will be "
"removed in v0.13.0 or v1.0, whichever comes first."
)
msg_hint = "Please remove this option."
is_generative_task = self.task in _RUNNER_TASKS["generate"]
is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
if is_generative_model and is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"generate` to continue using this model "
"as a generative model."
)
elif is_pooling_task:
runner = "pooling"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"pooling` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
elif is_generative_model or is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = "Please remove this option"
elif is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--convert "
f"{convert}` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
else:
# Neither generative nor pooling model - try to convert if possible
if is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--runner pooling "
f"--convert {convert}` to continue using this model "
"as a pooling model."
)
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError(
"The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}. Found: {debug_info}"
)
self.runner = runner
self.convert = convert
msg = f"{msg_prefix} {msg_hint}"
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.runner_type = self._get_runner_type(architectures, self.runner)
self.convert_type = self._get_convert_type(
architectures, self.runner_type, self.convert
......@@ -657,6 +542,11 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
if self.is_encoder_decoder:
self.mm_processor_cache_gb = 0
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
# Init multimodal config if needed
if self._model_info.supports_multimodal:
if (
......@@ -710,6 +600,14 @@ class ModelConfig:
self._verify_cuda_graph()
self._verify_bnb_config()
@field_validator("tokenizer", "max_model_len", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
    """Skip validation if the value is `None` when initialisation is delayed.

    Args:
        value: The raw value supplied for `tokenizer` or `max_model_len`.
        handler: Callable applied to every non-`None` value.

    Returns:
        `None` unchanged, otherwise the result of `handler(value)`.
    """
    if value is None:
        return value
    return handler(value)
@field_validator("tokenizer_mode", mode="after")
@classmethod
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
    """Return `tokenizer_mode` converted to lower case."""
    return tokenizer_mode.lower()
......@@ -723,10 +621,19 @@ class ModelConfig:
@model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
    """Called after __post_init__

    Checks the types of the fields whose validation is skipped while they
    are still `None`.

    Returns:
        The config instance itself.

    Raises:
        ValueError: If `tokenizer` is not a `str` or `max_model_len` is
            not an `int`.
    """
    if not isinstance(self.tokenizer, str):
        raise ValueError(
            f"tokenizer must be a string, got "
            f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
            "Please provide a valid tokenizer path or HuggingFace model ID."
        )
    if not isinstance(self.max_model_len, int):
        raise ValueError(
            f"max_model_len must be a positive integer, "
            f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "
            "Example: max_model_len=2048"
        )
    return self
def _get_transformers_backend_cls(self) -> str:
......@@ -906,6 +813,13 @@ class ModelConfig:
runner_type: RunnerType,
convert: ConvertOption,
) -> ConvertType:
if convert == "reward":
logger.warning(
"`--convert reward` is deprecated and will be removed in v0.15. "
"Please use `--convert embed` instead."
)
return "embed"
if convert != "auto":
return convert
......@@ -921,22 +835,6 @@ class ModelConfig:
return convert_type
def _get_default_pooling_task(
    self,
    architectures: list[str],
) -> Literal["embed", "classify", "reward"]:
    """Pick the default pooling task for the given architectures.

    Returns "classify" when the registry reports a cross-encoder model.
    Otherwise returns the convert type of the first architecture for which
    `try_match_architecture_defaults` yields a match for the pooling
    runner, falling back to "embed" when nothing matches.
    """
    if self.registry.is_cross_encoder_model(architectures, self):
        return "classify"

    for arch in architectures:
        match = try_match_architecture_defaults(arch, runner_type="pooling")
        if match:
            _, (_, convert_type) = match
            assert convert_type != "none"
            return convert_type

    return "embed"
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
quant_cfg = getattr(hf_config, "quantization_config", None)
if quant_cfg is None:
......@@ -1308,7 +1206,15 @@ class ModelConfig:
// block.attention.n_heads_in_group
)
raise RuntimeError("Couldn't determine number of kv heads")
raise RuntimeError(
"Could not determine the number of key-value attention heads "
"from model configuration. "
f"Model: {self.model}, Architecture: {self.architectures}. "
"This usually indicates an unsupported model architecture or "
"missing configuration. "
"Please check if your model is supported at: "
"https://docs.vllm.ai/en/latest/models/supported_models.html"
)
if self.is_attention_free:
return 0
......@@ -1902,6 +1808,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")),
("ForTokenClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),
......
......@@ -317,11 +317,6 @@ class ParallelConfig:
"num_redundant_experts."
)
if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self
@property
......
......@@ -111,13 +111,15 @@ class PoolerConfig:
def get_use_activation(o: object):
    """Return the value of a deprecated `softmax`/`activation` attribute of `o`.

    `softmax` is checked first, then `activation`. When one of them is
    present and not `None`, a one-time deprecation warning is logged and
    that attribute's value is returned.
    """
    # The walrus target must be parenthesised: `x := f() is not None` would
    # bind the boolean result of the comparison instead of the attribute.
    if (softmax := getattr(o, "softmax", None)) is not None:
        logger.warning_once(
            "softmax will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return softmax

    if (activation := getattr(o, "activation", None)) is not None:
        logger.warning_once(
            "activation will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return activation
    # NOTE(review): the visible hunk ends here; confirm any fallback return
    # that follows in the full file is still in place.
......
......@@ -122,10 +122,12 @@ class SchedulerConfig:
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""
disable_hybrid_kv_cache_manager: bool = False
disable_hybrid_kv_cache_manager: bool | None = None
"""If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
"""
async_scheduling: bool = False
......
......@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
)
def getattr_iter(
    object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
    """
    A helper function that retrieves an attribute from an object which may
    have multiple possible names. This is useful when fetching attributes from
    arbitrary `transformers.PretrainedConfig` instances.

    In the case where the first name in `names` is the preferred name, and
    any other names are deprecated aliases, setting `warn=True` will log a
    warning when a deprecated name is used.

    Args:
        object: The object to read the attribute from.
        names: Candidate attribute names, tried in iteration order.
        default: Value returned when none of the names is present.
        warn: Whether to log a warning when a name other than the first
            one is the one that was found.

    Returns:
        The value of the first attribute found, otherwise `default`.
    """
    # `names` is only required to be iterable, so remember the first name
    # while iterating instead of indexing with `names[0]`.
    preferred_name = None
    for i, name in enumerate(names):
        if i == 0:
            preferred_name = name
        if hasattr(object, name):
            if warn and i > 0:
                logger.warning_once(
                    "%s contains a deprecated attribute name '%s'. "
                    "Please use the preferred attribute name '%s' instead.",
                    type(object).__name__,
                    name,
                    preferred_name,
                )
            return getattr(object, name)
    return default
......
......@@ -666,8 +666,9 @@ class VllmConfig:
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
......@@ -692,22 +693,29 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and self.model_config is not None
):
if self.model_config.pooler_config is not None:
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder:
logger.warning_once(
"Encoder-decoder models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
......@@ -742,27 +750,17 @@ class VllmConfig:
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()
if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY
self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens,
if (
self.model_config
and self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if (
self.kv_events_config is not None
......@@ -812,11 +810,6 @@ class VllmConfig:
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
......@@ -894,17 +887,48 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
if not self.scheduler_config.disable_hybrid_kv_cache_manager:
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-hybrid-kv-cache-manager): error if runtime
# disables it
# - No preference: auto-disable for unsupported features (e.g. kv connector)
# - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it
need_disable_hybrid_kv_cache_manager = False
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
need_disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
need_disable_hybrid_kv_cache_manager = True
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector for now.
# TODO(Kuntai): have a more elegent solution to check and
# turn off HMA for connector that does not support HMA.
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
......@@ -912,33 +936,26 @@ class VllmConfig:
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py."
" of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
elif (
self.scheduler_config.disable_hybrid_kv_cache_manager is False
and need_disable_hybrid_kv_cache_manager
):
raise ValueError(
"Hybrid KV cache manager was explicitly enabled but is not "
"supported in this configuration. Consider omitting the "
"--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
" automatically."
)
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
......@@ -1006,7 +1023,7 @@ class VllmConfig:
max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
......@@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
# Memory fence for cross-process shared memory visibility.
# Required for correct producer-consumer synchronization when using
# shared memory without locks.
_memory_fence_lock = threading.Lock()


def memory_fence() -> None:
    """
    Full memory barrier for shared memory synchronization.

    Ensures all prior memory writes are visible to other processes before
    any subsequent reads. This is critical for lock-free producer-consumer
    patterns using shared memory.

    Implementation acquires and immediately releases a lock. Python's
    threading.Lock provides sequentially consistent memory barrier semantics
    across all major platforms (POSIX, Windows). This is a lightweight
    operation (~20ns) that guarantees:
    - All stores before the barrier are visible to other threads/processes
    - All loads after the barrier see the latest values
    """
    # Lock acquire/release provides full memory barrier semantics.
    # Using context manager ensures lock release even on exceptions.
    with _memory_fence_lock:
        pass
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode `value` as exactly `size` bytes in big-endian byte order."""
    return value.to_bytes(size, byteorder="big")
......@@ -414,6 +442,10 @@ class MessageQueue:
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence()
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
......@@ -458,6 +490,10 @@ class MessageQueue:
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
......@@ -473,6 +509,10 @@ class MessageQueue:
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though writer has updated them.
memory_fence()
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
......@@ -513,6 +553,10 @@ class MessageQueue:
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
# Memory fence ensures the read flag is visible to the writer.
# Without this, writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment