Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -355,7 +355,7 @@ def kernel_unified_attention_2d( ...@@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@triton.jit @triton.jit
def kernel_unified_attention_3d( def kernel_unified_attention_3d(
segm_output_ptr, segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size] # [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments] segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size]
...@@ -749,6 +749,11 @@ def unified_attention( ...@@ -749,6 +749,11 @@ def unified_attention(
q_descale, q_descale,
k_descale, k_descale,
v_descale, v_descale,
seq_threshold_3D=None,
num_par_softmax_segments=None,
softmax_segm_output=None,
softmax_segm_max=None,
softmax_segm_expsum=None,
alibi_slopes=None, alibi_slopes=None,
output_scale=None, output_scale=None,
qq_bias=None, qq_bias=None,
...@@ -793,8 +798,19 @@ def unified_attention( ...@@ -793,8 +798,19 @@ def unified_attention(
TILE_SIZE_PREFILL = 32 TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# if batch contains a prefill # Launch the 2D kernel if
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
if (
seq_threshold_3D is None
or num_par_softmax_segments is None
or softmax_segm_output is None
or softmax_segm_max is None
or softmax_segm_expsum is None
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
):
kernel_unified_attention_2d[ kernel_unified_attention_2d[
( (
total_num_q_blocks, total_num_q_blocks,
...@@ -847,37 +863,12 @@ def unified_attention( ...@@ -847,37 +863,12 @@ def unified_attention(
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
) )
else: else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default kernel_unified_attention_3d[
# value that showed good performance in tests (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
NUM_SEGMENTS = 16 ](
segm_output_ptr=softmax_segm_output,
segm_output = torch.empty( segm_max_ptr=softmax_segm_max,
q.shape[0], segm_expsum_ptr=softmax_segm_expsum,
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
query_ptr=q, query_ptr=q,
key_cache_ptr=k, key_cache_ptr=k,
value_cache_ptr=v, value_cache_ptr=v,
...@@ -917,13 +908,13 @@ def unified_attention( ...@@ -917,13 +908,13 @@ def unified_attention(
BLOCK_Q=BLOCK_Q, BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
) )
reduce_segments[(q.shape[0], num_query_heads)]( reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out, output_ptr=out,
segm_output_ptr=segm_output, segm_output_ptr=softmax_segm_output,
segm_max_ptr=segm_max, segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=segm_expsum, segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k, seq_lens_ptr=seqused_k,
num_seqs=num_seqs, num_seqs=num_seqs,
num_query_heads=num_query_heads, num_query_heads=num_query_heads,
...@@ -936,6 +927,6 @@ def unified_attention( ...@@ -936,6 +927,6 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q, query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q, BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
) )
...@@ -16,6 +16,7 @@ import einops ...@@ -16,6 +16,7 @@ import einops
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -44,9 +45,7 @@ def flash_attn_maxseqlen_wrapper( ...@@ -44,9 +45,7 @@ def flash_attn_maxseqlen_wrapper(
dropout_p=0.0, dropout_p=0.0,
causal=False, causal=False,
) )
context_layer = einops.rearrange( context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
return context_layer return context_layer
...@@ -59,8 +58,7 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -59,8 +58,7 @@ def flash_attn_maxseqlen_wrapper_fake(
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
) -> torch.Tensor: ) -> torch.Tensor:
b, s, h, d = q.shape return torch.empty_like(q)
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op( direct_register_custom_op(
...@@ -92,6 +90,13 @@ def torch_sdpa_wrapper( ...@@ -92,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = [] outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
...@@ -106,7 +111,6 @@ def torch_sdpa_wrapper( ...@@ -106,7 +111,6 @@ def torch_sdpa_wrapper(
output_i = einops.rearrange(output_i, "b h s d -> b s h d ") output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
return context_layer return context_layer
...@@ -116,8 +120,7 @@ def torch_sdpa_wrapper_fake( ...@@ -116,8 +120,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
b, s, h, d = q.shape return torch.empty_like(q)
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
direct_register_custom_op( direct_register_custom_op(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from functools import cache from functools import cache
from typing import cast, get_args from typing import NamedTuple, cast, get_args
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import ( from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP, MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum, MambaAttentionBackendEnum,
...@@ -19,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -19,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
class AttentionSelectorConfig(NamedTuple):
    """Settings that are bundled together and handed to attention backend selection.

    Being a ``NamedTuple``, an instance is immutable and hashable, so it can be
    used as an argument of a ``functools.cache``-decorated function.
    """

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        # Every field is rendered as ``name=value`` in declaration order, with
        # the value formatted the same way an f-string placeholder formats it
        # (i.e. not via repr(), so strings appear without quotes).
        rendered_fields = ", ".join(
            f"{name}={value}" for name, value in zip(self._fields, self)
        )
        return f"AttentionSelectorConfig({rendered_fields})"
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -44,8 +68,7 @@ def get_attn_backend( ...@@ -44,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend( attn_selector_config = AttentionSelectorConfig(
backend=backend_enum,
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
...@@ -54,58 +77,26 @@ def get_attn_backend( ...@@ -54,58 +77,26 @@ def get_attn_backend(
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix, use_mm_prefix=use_mm_prefix,
attn_type=attn_type, attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
) )
@cache @cache
def _cached_get_attn_backend( def _cached_get_attn_backend(
backend, backend,
head_size: int, attn_selector_config: AttentionSelectorConfig,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls) attention_cls = current_platform.get_attn_backend_cls(
if "use_v1" in sig.parameters: backend,
logger.warning_once( attn_selector_config=attn_selector_config,
"use_v1 parameter for get_attn_backend_cls is deprecated and will " )
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
True, # use_v1
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(
f"Invalid attention backend for {current_platform.device_name}" f"Invalid attention backend for {current_platform.device_name}"
......
...@@ -235,7 +235,9 @@ async def get_request( ...@@ -235,7 +235,9 @@ async def get_request(
def calculate_metrics_for_embeddings( def calculate_metrics_for_embeddings(
outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float] outputs: list[RequestFuncOutput],
dur_s: float,
selected_percentiles: list[float],
) -> EmbedBenchmarkMetrics: ) -> EmbedBenchmarkMetrics:
"""Calculate the metrics for the embedding requests. """Calculate the metrics for the embedding requests.
...@@ -788,7 +790,7 @@ async def benchmark( ...@@ -788,7 +790,7 @@ async def benchmark(
) )
print( print(
"{:<40} {:<10.2f}".format( "{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput "Total token throughput (tok/s):", metrics.total_token_throughput
) )
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the cold and warm startup time of vLLM models.
This script measures total startup time (including model loading, compilation,
and cache operations) for both cold and warm scenarios:
- Cold startup: Fresh start with no caches (temporary cache directories)
- Warm startup: Using cached compilation and model info
"""
import argparse
import dataclasses
import json
import multiprocessing
import os
import shutil
import tempfile
import time
from contextlib import contextmanager
from typing import Any
import numpy as np
from tqdm import tqdm
from vllm.benchmarks.lib.utils import (
convert_to_pytorch_benchmark_format,
write_to_json,
)
from vllm.engine.arg_utils import EngineArgs
@contextmanager
def cold_startup():
"""
Context manager to measure cold startup time:
1. Uses a temporary directory for vLLM cache to avoid any pollution
between cold startup iterations.
2. Uses inductor's fresh_cache to clear torch.compile caches.
"""
from torch._inductor.utils import fresh_cache
# Use temporary directory for caching to avoid any pollution between cold startups
original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
try:
os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
with fresh_cache():
yield
finally:
# Clean up temporary cache directory
shutil.rmtree(temp_cache_dir, ignore_errors=True)
if original_cache_root:
os.environ["VLLM_CACHE_ROOT"] = original_cache_root
else:
os.environ.pop("VLLM_CACHE_ROOT", None)
def run_startup_in_subprocess(engine_args_dict, result_queue):
    """Start an LLM and report its startup timings through ``result_queue``.

    Intended as the target of a child process so that every measurement runs
    in isolation from the others.

    On success one dict with the keys ``total_startup_time`` and
    ``compilation_time`` is put on the queue. If anything raises, ``None`` is
    put first, followed by the exception message as a string.
    """
    try:
        # Imported inside the subprocess to avoid issues with forking.
        from vllm import LLM
        from vllm.engine.arg_utils import EngineArgs

        parsed_engine_args = EngineArgs(**engine_args_dict)

        # The timed region covers the conversion back to keyword arguments and
        # the LLM construction itself.
        began_at = time.perf_counter()
        llm = LLM(**dataclasses.asdict(parsed_engine_args))
        elapsed = time.perf_counter() - began_at

        # The compilation time is optional: fall back to 0.0 whenever the
        # config chain is missing or the compilation config is None.
        engine_config = getattr(llm.llm_engine, "vllm_config", None)
        compile_config = getattr(engine_config, "compilation_config", None)
        if compile_config is None:
            compile_seconds = 0.0
        else:
            compile_seconds = compile_config.compilation_time

        payload = {
            "total_startup_time": elapsed,
            "compilation_time": compile_seconds,
        }
        result_queue.put(payload)
    except Exception as exc:
        # Failure protocol: a None marker first, then the error text.
        result_queue.put(None)
        result_queue.put(str(exc))
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the benchmark results as PyTorch-benchmark-format JSON files.

    One conversion is made per measurement group (cold/warm x
    startup/compilation). Each group that yields non-empty records is written
    to ``<output_json without extension>.<group>.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` provides the base path.
        results: Aggregated results containing, for every group,
            ``avg_<group>_time``, ``<group>_times`` and ``<group>_percentiles``.
    """
    stem = os.path.splitext(args.output_json)[0]

    # Processed in this order: cold before warm, startup before compilation.
    groups = (
        "cold_startup",
        "cold_compilation",
        "warm_startup",
        "warm_compilation",
    )
    for group in groups:
        average_key = f"avg_{group}_time"
        times_key = f"{group}_times"
        percentiles_key = f"{group}_percentiles"

        records = convert_to_pytorch_benchmark_format(
            args=args,
            metrics={average_key: results[average_key]},
            extra_info={
                times_key: results[times_key],
                percentiles_key: results[percentiles_key],
            },
        )
        # Nothing is written when the conversion returns no records.
        if records:
            write_to_json(f"{stem}.{group}.pytorch.json", records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the startup-benchmark options, then the engine options.

    Args:
        parser: Parser to extend.

    Returns:
        Whatever ``EngineArgs.add_cli_args`` returns for the extended parser.
    """
    # (flag, default, help) for the integer iteration-count options.
    iteration_options = (
        ("--num-iters-cold", 5, "Number of cold startup iterations."),
        (
            "--num-iters-warmup",
            3,
            "Number of warmup iterations before benchmarking warm startups.",
        ),
        ("--num-iters-warm", 5, "Number of warm startup iterations."),
    )
    for flag, default_count, help_text in iteration_options:
        parser.add_argument(flag, type=int, default=default_count, help=help_text)

    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the startup time results in JSON format.",
    )
    return EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Benchmark cold and warm vLLM startup time and report the results.

    Every measurement starts the LLM in a separate spawned process. Cold
    iterations run inside ``cold_startup()``; after a number of unmeasured
    warmup runs, warm iterations are measured without it. Averages and
    percentiles are printed, and when ``args.output_json`` is set they are
    also written to that file and passed to
    ``save_to_pytorch_benchmark_format``.

    Args:
        args: Parsed CLI arguments (see ``add_cli_args``).

    Raises:
        RuntimeError: If a subprocess reports an error or exits without
            reporting a result.
    """
    # Local import: only this function needs the ``queue.Empty`` exception.
    import queue

    # Set multiprocessing start method to 'spawn' for clean process isolation.
    # This ensures each subprocess starts fresh without inheriting state.
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """Create an LLM in a subprocess and return its timing metrics.

        The messages are read from the queue *before* the child is joined: a
        process that has put items on a ``multiprocessing.Queue`` may not
        terminate until those items have been consumed, and ``Queue.empty()``
        is documented as unreliable, so joining first and checking ``empty()``
        afterwards is fragile.
        """
        no_message = object()
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=run_startup_in_subprocess,
            # engine_args is converted to a dictionary for pickling.
            args=(dataclasses.asdict(engine_args), result_queue),
        )
        process.start()

        def next_message():
            """Return the next message, or ``no_message`` if the child died."""
            while True:
                try:
                    return result_queue.get(timeout=1.0)
                except queue.Empty:
                    if process.is_alive():
                        continue
                    # The child is gone; look one last time in case its
                    # message arrived between the timeout and this check.
                    try:
                        return result_queue.get(timeout=1.0)
                    except queue.Empty:
                        return no_message

        result = next_message()
        # A None marker announces a failure; the error text follows it.
        error_msg = next_message() if result is None else no_message

        # Wait for the child to exit so that it has released its resources
        # before the next iteration starts.
        process.join()

        if result is no_message:
            raise RuntimeError("Subprocess did not return a result")
        if result is None:
            if error_msg is no_message:
                raise RuntimeError("Subprocess failed with unknown error")
            raise RuntimeError(f"Subprocess failed: {error_msg}")
        return result

    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    print("Measuring cold startup time...\n")
    cold_startup_times = []
    cold_compilation_times = []
    for _ in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
        cold_startup_times.append(metrics["total_startup_time"])
        cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    print("\nMeasuring warm startup time...\n")
    warm_startup_times = []
    warm_compilation_times = []
    for _ in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    percentages = [10, 25, 50, 75, 90, 99]

    def summarize(times):
        """Return ``(mean, percentiles)`` of ``times`` for ``percentages``."""
        samples = np.array(times)
        return np.mean(samples), np.percentile(samples, percentages)

    avg_cold_startup, cold_startup_percentiles = summarize(cold_startup_times)
    avg_cold_compilation, cold_compilation_percentiles = summarize(
        cold_compilation_times
    )
    avg_warm_startup, warm_startup_percentiles = summarize(warm_startup_times)
    avg_warm_compilation, warm_compilation_percentiles = summarize(
        warm_compilation_times
    )

    def print_section(
        title, avg_startup, avg_compilation, startup_pcts, compilation_pcts
    ):
        """Print the averages and percentiles of one scenario."""
        print(title)
        print(f"Avg total startup time: {avg_startup:.2f} seconds")
        print(f"Avg compilation time: {avg_compilation:.2f} seconds")
        print("Startup time percentiles:")
        for percentage, percentile in zip(percentages, startup_pcts):
            print(f" {percentage}%: {percentile:.2f} seconds")
        print("Compilation time percentiles:")
        for percentage, percentile in zip(percentages, compilation_pcts):
            print(f" {percentage}%: {percentile:.2f} seconds")

    print("\n" + "=" * 60)
    print("STARTUP TIME BENCHMARK RESULTS")
    print("=" * 60)

    print_section(
        "\nCOLD STARTUP:",
        avg_cold_startup,
        avg_cold_compilation,
        cold_startup_percentiles,
        cold_compilation_percentiles,
    )
    print_section(
        "\nWARM STARTUP:",
        avg_warm_startup,
        avg_warm_compilation,
        warm_startup_percentiles,
        warm_compilation_percentiles,
    )

    print("=" * 60)

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_cold_startup_time": float(avg_cold_startup),
            "avg_cold_compilation_time": float(avg_cold_compilation),
            "cold_startup_times": cold_startup_times,
            "cold_compilation_times": cold_compilation_times,
            "cold_startup_percentiles": dict(
                zip(percentages, cold_startup_percentiles.tolist())
            ),
            "cold_compilation_percentiles": dict(
                zip(percentages, cold_compilation_percentiles.tolist())
            ),
            "avg_warm_startup_time": float(avg_warm_startup),
            "avg_warm_compilation_time": float(avg_warm_compilation),
            "warm_startup_times": warm_startup_times,
            "warm_compilation_times": warm_compilation_times,
            "warm_startup_percentiles": dict(
                zip(percentages, warm_startup_percentiles.tolist())
            ),
            "warm_compilation_percentiles": dict(
                zip(percentages, warm_compilation_percentiles.tolist())
            ),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
...@@ -141,7 +141,25 @@ class CompilerManager: ...@@ -141,7 +141,25 @@ class CompilerManager:
# we use ast.literal_eval to parse the data # we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals. # because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe. # do not use eval(), it is unsafe.
self.cache = ast.literal_eval(f.read()) cache = ast.literal_eval(f.read())
def check_type(value, ty):
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]:
range_tuple, graph_index, compiler_name = key
check_type(graph_index, int)
check_type(compiler_name, str)
if isinstance(range_tuple, tuple):
start, end = range_tuple
check_type(start, int)
check_type(end, int)
range_tuple = Range(start=start, end=end)
check_type(range_tuple, Range)
return range_tuple, graph_index, compiler_name
self.cache = {parse_key(key): value for key, value in cache.items()}
self.compiler.initialize_cache( self.compiler.initialize_cache(
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
...@@ -445,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -445,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# the tag for the part of model being compiled, # the tag for the part of model being compiled,
# e.g. backbone/eagle_head # e.g. backbone/eagle_head
model_tag: str = "backbone" model_tag: str = "backbone"
model_is_encoder: bool = False
@contextmanager @contextmanager
def set_model_tag(tag: str): def set_model_tag(tag: str, is_encoder: bool = False):
"""Context manager to set the model tag.""" """Context manager to set the model tag."""
global model_tag global model_tag
global model_is_encoder
assert tag != model_tag, ( assert tag != model_tag, (
f"Model tag {tag} is the same as the current tag {model_tag}." f"Model tag {tag} is the same as the current tag {model_tag}."
) )
old_tag = model_tag old_tag = model_tag
old_is_encoder = model_is_encoder
model_tag = tag model_tag = tag
model_is_encoder = is_encoder
try: try:
yield yield
finally: finally:
model_tag = old_tag model_tag = old_tag
model_is_encoder = old_is_encoder
class VllmBackend: class VllmBackend:
...@@ -505,6 +529,9 @@ class VllmBackend: ...@@ -505,6 +529,9 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc. # them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag self.prefix = prefix or model_tag
# Mark compilation for encoder.
self.is_encoder = model_is_encoder
# Passes to run on the graph post-grad. # Passes to run on the graph post-grad.
self.pass_manager = resolve_obj_by_qualname( self.pass_manager = resolve_obj_by_qualname(
current_platform.get_pass_manager_cls() current_platform.get_pass_manager_cls()
......
...@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType ...@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from .monitor import start_monitoring_torch_compile from .monitor import start_monitoring_torch_compile
...@@ -316,7 +316,13 @@ def _support_torch_compile( ...@@ -316,7 +316,13 @@ def _support_torch_compile(
def _mark_dynamic_inputs(mod, type, *args, **kwargs): def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims): def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED: if type == DynamicShapesType.UNBACKED:
torch._dynamo.decorators.mark_unbacked(arg, dims) if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else: else:
torch._dynamo.mark_dynamic(arg, dims) torch._dynamo.mark_dynamic(arg, dims)
...@@ -350,7 +356,13 @@ def _support_torch_compile( ...@@ -350,7 +356,13 @@ def _support_torch_compile(
if isinstance(arg, torch.Tensor): if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing # In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.decorators.mark_unbacked(arg, dims) if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
...@@ -378,14 +390,6 @@ def _support_torch_compile( ...@@ -378,14 +390,6 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch. compile artifact from scratch.
""" """
# Validate that AOT compile is not used with unbacked dynamic
# shapes. aot_compile re-allocates backed symbols post dynamo!
if ds_type == DynamicShapesType.UNBACKED:
raise ValueError(
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
"when VLLM_USE_AOT_COMPILE is enabled."
)
from .caching import compilation_config_hash_factors from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config) factors: list[str] = compilation_config_hash_factors(self.vllm_config)
...@@ -488,6 +492,12 @@ def _support_torch_compile( ...@@ -488,6 +492,12 @@ def _support_torch_compile(
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS: if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True fx_config_patches["backed_size_oblivious"] = True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0.dev+
inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0.dev"):
inductor_config_patches["assume_32bit_indexing"] = True
with ( with (
patch.object( patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call InliningInstructionTranslator, "inline_call_", patched_inline_call
...@@ -496,6 +506,7 @@ def _support_torch_compile( ...@@ -496,6 +506,7 @@ def _support_torch_compile(
maybe_use_cudagraph_partition_wrapper(self.vllm_config), maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches), torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(), _torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches),
): ):
if envs.VLLM_USE_AOT_COMPILE: if envs.VLLM_USE_AOT_COMPILE:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
......
...@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -23,17 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant, kNvfp4Quant,
kStaticTensorScale, kStaticTensorScale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear_for_nk,
)
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { ...@@ -118,21 +115,18 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
class RMSNormQuantPattern: class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey): def __init__(
self,
epsilon: float,
key: FusedRMSQuantKey,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
):
self.epsilon = epsilon self.epsilon = epsilon
self.quant_dtype = key.quant.dtype self.quant_dtype = key.quant.dtype
config = get_current_vllm_config() config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
# groupwise FP8 linear uses col major scales if deepgemm and cutlass
using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
self.model_dtype,
config.model_config.hf_config.intermediate_size,
config.model_config.hf_config.hidden_size,
)
use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key] self.FUSED_OP = FUSED_OPS[key]
...@@ -142,7 +136,7 @@ class RMSNormQuantPattern: ...@@ -142,7 +136,7 @@ class RMSNormQuantPattern:
else MatcherFusedAddRMSNorm(epsilon) else MatcherFusedAddRMSNorm(epsilon)
) )
self.quant_matcher = MatcherQuantFP8( self.quant_matcher = MatcherQuantFP8(
key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0 key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
) )
...@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -260,6 +254,8 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric=True, symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
): ):
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
...@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -267,7 +263,11 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
) )
self.group_shape = group_shape self.group_shape = group_shape
super().__init__(epsilon, key) self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
...@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -283,9 +283,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
input = input.to(dtype=self.model_dtype) input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype) result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale( scale = self.quant_matcher.make_scale(input, self.has_col_major_scales)
input, transposed=self.quant_matcher.use_col_major_scales
)
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
result=result, result=result,
...@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -296,7 +294,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None, scale_ub=None,
residual=residual, residual=residual,
group_size=self.group_shape[1], group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales, is_scale_transposed=self.has_col_major_scales,
) )
# result, residual, scale # result, residual, scale
...@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -318,6 +316,8 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape, group_shape: GroupShape,
symmetric=True, symmetric=True,
has_col_major_scales: bool = False,
is_e8m0: bool = False,
): ):
scale = ScaleDesc(torch.float32, False, group_shape) scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey( key = FusedRMSQuantKey(
...@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -325,7 +325,9 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
) )
self.group_shape = group_shape self.group_shape = group_shape
super().__init__(epsilon, key) super().__init__(
epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0
)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(input: torch.Tensor, weight: torch.Tensor): def pattern(input: torch.Tensor, weight: torch.Tensor):
...@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -340,7 +342,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
result = torch.empty_like(input, dtype=self.quant_dtype) result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale( scale = self.quant_matcher.make_scale(
input, transposed=self.quant_matcher.use_col_major_scales input, transposed=self.quant_matcher.has_col_major_scales
) )
at = auto_functionalized( at = auto_functionalized(
self.FUSED_OP, self.FUSED_OP,
...@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -352,7 +354,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
scale_ub=None, scale_ub=None,
residual=None, residual=None,
group_size=self.group_shape[1], group_size=self.group_shape[1],
is_scale_transposed=self.quant_matcher.use_col_major_scales, is_scale_transposed=self.quant_matcher.has_col_major_scales,
) )
# result, scale # result, scale
...@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -489,27 +491,6 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Make sure fused add patterns are before simple rms norm, # Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops # as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse fused_add_rms_norm + fp8 group quant
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
).register(self.patterns)
FusedAddRMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant # Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns self.patterns
...@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -526,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
# Fuse rms_norm + dynamic per-token fp8 quant # Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Only register group quant patterns on CUDA where the C++ op exists
if current_platform.is_cuda():
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
for has_col_major_scales in [True, False]:
for is_e8m0 in [True, False]:
# Fuse fused_add_rms_norm + fp8 group quant
FusedAddRMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
# Fuse rms_norm + fp8 group quant
RMSNormGroupQuantPattern(
epsilon,
FP8_DTYPE,
group_shape=group_shape,
has_col_major_scales=has_col_major_scales,
is_e8m0=is_e8m0,
).register(self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
......
...@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -234,24 +234,30 @@ class MatcherQuantFP8(MatcherCustomOp):
self, self,
quant_key: QuantKey, quant_key: QuantKey,
enabled: bool | None = None, enabled: bool | None = None,
use_col_major_scales: bool = False, has_col_major_scales: bool = False,
use_e8m0: bool = False, is_e8m0: bool = False,
): ):
if enabled is None: if enabled is None:
enabled = QuantFP8.enabled() enabled = QuantFP8.enabled()
super().__init__(enabled) super().__init__(enabled)
self.quant_key = quant_key self.quant_key = quant_key
self.use_col_major_scales = use_col_major_scales
self.use_e8m0 = use_e8m0
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key] self.QUANT_OP = QUANT_OPS[quant_key]
self.has_col_major_scales = has_col_major_scales
self.is_e8m0 = is_e8m0
assert quant_key.dtype == current_platform.fp8_dtype(), ( assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by" "Only QuantFP8 supported by"
) )
assert quant_key.scale2 is None assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) self.quant_fp8 = QuantFP8(
quant_key.scale.static,
quant_key.scale.group_shape,
column_major_scales=has_col_major_scales,
use_ue8m0=is_e8m0,
)
def forward_custom( def forward_custom(
self, self,
...@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -264,7 +270,7 @@ class MatcherQuantFP8(MatcherCustomOp):
if self.quant_key.scale.group_shape.is_per_group(): if self.quant_key.scale.group_shape.is_per_group():
assert scale is None assert scale is None
scale = self.make_scale(input, transposed=self.use_col_major_scales) scale = self.make_scale(input, transposed=self.has_col_major_scales)
finfo = torch.finfo(self.quant_key.dtype) finfo = torch.finfo(self.quant_key.dtype)
fp8_min = finfo.min fp8_min = finfo.min
...@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp): ...@@ -279,7 +285,7 @@ class MatcherQuantFP8(MatcherCustomOp):
eps=1e-10, eps=1e-10,
fp8_min=fp8_min, fp8_min=fp8_min,
fp8_max=fp8_max, fp8_max=fp8_max,
scale_ue8m0=self.use_e8m0, scale_ue8m0=self.is_e8m0,
) )
return result, scale return result, scale
......
...@@ -53,12 +53,7 @@ class PiecewiseBackend: ...@@ -53,12 +53,7 @@ class PiecewiseBackend:
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1 self.is_full_graph = total_piecewise_compiles == 1
# TODO: we need to generalize encoder compilation to other models self.is_encoder_compilation = vllm_backend.is_encoder
self.is_encoder_compilation = vllm_backend.prefix in [
"Qwen2_5_VisionPatchEmbed",
"Qwen2_5_VisionPatchMerger",
"Qwen2_5_VisionBlock",
]
self.compile_ranges = self.compilation_config.get_compile_ranges() self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation: if self.is_encoder_compilation:
......
...@@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr = self.check_invariants_and_forward compiled_ptr = self.check_invariants_and_forward
aot_context = nullcontext()
if envs.VLLM_USE_AOT_COMPILE: if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"): if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
else: else:
msg = "torch._dynamo.config.enable_aot_compile is not " msg = "torch._dynamo.config.enable_aot_compile is not "
msg += "available. AOT compile is disabled and please " msg += "available. AOT compile is disabled and please "
msg += "upgrade PyTorch version to use AOT compile." msg += "upgrade PyTorch version to use AOT compile."
logger.warning(msg) logger.warning(msg)
self._compiled_callable = torch.compile( with aot_context:
compiled_ptr, self._compiled_callable = torch.compile(
fullgraph=True, compiled_ptr,
dynamic=False, fullgraph=True,
backend=backend, dynamic=False,
options=options, backend=backend,
) options=options,
)
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
......
...@@ -8,7 +8,7 @@ from dataclasses import field ...@@ -8,7 +8,7 @@ from dataclasses import field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import Field, TypeAdapter, field_validator from pydantic import ConfigDict, Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
...@@ -17,7 +17,6 @@ from vllm.config.utils import ( ...@@ -17,7 +17,6 @@ from vllm.config.utils import (
Range, Range,
config, config,
get_hash_factors, get_hash_factors,
handle_deprecated,
hash_factors, hash_factors,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -97,7 +96,7 @@ class CUDAGraphMode(enum.Enum): ...@@ -97,7 +96,7 @@ class CUDAGraphMode(enum.Enum):
@config @config
@dataclass @dataclass(config=ConfigDict(extra="forbid"))
class PassConfig: class PassConfig:
"""Configuration for custom Inductor passes. """Configuration for custom Inductor passes.
...@@ -127,27 +126,6 @@ class PassConfig: ...@@ -127,27 +126,6 @@ class PassConfig:
fuse_allreduce_rms: bool = Field(default=None) fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion.""" """Enable flashinfer allreduce fusion."""
# Deprecated flags
enable_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
"""
enable_attn_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_noop: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_sequence_parallelism: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use enable_sp instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_async_tp: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
fi_allreduce_fusion_max_size_mb: float | None = None fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which """The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a vllm should use flashinfer fused allreduce. Specified as a
...@@ -206,15 +184,7 @@ class PassConfig: ...@@ -206,15 +184,7 @@ class PassConfig:
Any future fields that don't affect compilation should be excluded. Any future fields that don't affect compilation should be excluded.
""" """
ignored_fields = [ return hash_factors(get_hash_factors(self, set()))
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
@field_validator( @field_validator(
"fuse_norm_quant", "fuse_norm_quant",
...@@ -224,12 +194,6 @@ class PassConfig: ...@@ -224,12 +194,6 @@ class PassConfig:
"enable_sp", "enable_sp",
"fuse_gemm_comms", "fuse_gemm_comms",
"fuse_allreduce_rms", "fuse_allreduce_rms",
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
mode="wrap", mode="wrap",
) )
@classmethod @classmethod
...@@ -242,49 +206,6 @@ class PassConfig: ...@@ -242,49 +206,6 @@ class PassConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Handle deprecation and defaults # Handle deprecation and defaults
# Map old flags to new flags and issue warnings
handle_deprecated(
self,
"enable_fusion",
["fuse_norm_quant", "fuse_act_quant"],
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_attn_fusion",
"fuse_attn_quant",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_sequence_parallelism",
"enable_sp",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_async_tp",
"fuse_gemm_comms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_fi_allreduce_fusion",
"fuse_allreduce_rms",
"v0.13.0 or v1.0.0, whichever is sooner",
)
handle_deprecated(
self,
"enable_noop",
"eliminate_noops",
"v0.13.0 or v1.0.0, whichever is sooner",
)
if not self.eliminate_noops: if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant: if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once( logger.warning_once(
...@@ -330,7 +251,7 @@ class DynamicShapesType(str, enum.Enum): ...@@ -330,7 +251,7 @@ class DynamicShapesType(str, enum.Enum):
@config @config
@dataclass @dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig: class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes.""" """Configuration to control/debug torch compile dynamic shapes."""
...@@ -369,7 +290,7 @@ class DynamicShapesConfig: ...@@ -369,7 +290,7 @@ class DynamicShapesConfig:
@config @config
@dataclass @dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig: class CompilationConfig:
"""Configuration for compilation. """Configuration for compilation.
...@@ -1011,9 +932,13 @@ class CompilationConfig: ...@@ -1011,9 +932,13 @@ class CompilationConfig:
self.splitting_ops = list(self._attention_ops) self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True added_default_splitting_ops = True
elif len(self.splitting_ops) == 0: elif len(self.splitting_ops) == 0:
logger.warning_once( if (
"Using piecewise compilation with empty splitting_ops" self.cudagraph_mode == CUDAGraphMode.PIECEWISE
) or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
):
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once( logger.warning_once(
"Piecewise compilation with empty splitting_ops do not" "Piecewise compilation with empty splitting_ops do not"
......
...@@ -64,6 +64,11 @@ class KVTransferConfig: ...@@ -64,6 +64,11 @@ class KVTransferConfig:
enable_permute_local_kv: bool = False enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer""" """Experiment feature flag to enable HND to NHD KV Transfer"""
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
"""Policy for handling KV cache load failures.
'recompute': reschedule the request to recompute failed blocks (default)
'fail': immediately fail the request with an error finish reason"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -8,7 +8,7 @@ from functools import cached_property ...@@ -8,7 +8,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES from transformers.configuration_utils import ALLOWED_LAYER_TYPES
...@@ -73,17 +73,6 @@ logger = init_logger(__name__) ...@@ -73,17 +73,6 @@ logger = init_logger(__name__)
RunnerOption = Literal["auto", RunnerType] RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"] ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType] ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal[
"auto",
"generate",
"embedding",
"embed",
"classify",
"score",
"reward",
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[ LogprobsMode = Literal[
...@@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig] ...@@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
LayerBlockType = Literal["attention", "linear_attention", "mamba"] LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"],
"pooling": ["embedding", "embed", "classify", "score", "reward"],
"draft": ["draft"],
}
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [], "generate": [],
"pooling": ["embed", "classify", "reward"], "pooling": ["embed", "classify", "reward"],
...@@ -126,13 +109,7 @@ class ModelConfig: ...@@ -126,13 +109,7 @@ class ModelConfig:
"""Convert the model using adapters defined in """Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to [vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks.""" adapt a text generation model to be used for pooling tasks."""
task: TaskOption | None = None tokenizer: str = Field(default=None)
"""[DEPRECATED] The task to use the model for. If the model supports more
than one model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner.
"""
tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model """Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto" tokenizer_mode: TokenizerMode | str = "auto"
...@@ -187,7 +164,7 @@ class ModelConfig: ...@@ -187,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub. """The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version.""" use the default version."""
max_model_len: SkipValidation[int] = None # type: ignore max_model_len: int = Field(default=None, gt=0)
"""Model context length (prompt and output). If unspecified, will be """Model context length (prompt and output). If unspecified, will be
automatically derived from the model config. automatically derived from the model config.
...@@ -198,7 +175,7 @@ class ModelConfig: ...@@ -198,7 +175,7 @@ class ModelConfig:
- 25.6k -> 25,600""" - 25.6k -> 25,600"""
spec_target_max_model_len: int | None = None spec_target_max_model_len: int | None = None
"""Specify the maximum length for spec decoding draft models.""" """Specify the maximum length for spec decoding draft models."""
quantization: SkipValidation[QuantizationMethods | None] = None quantization: QuantizationMethods | str | None = None
"""Method used to quantize the weights. If `None`, we first check the """Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is `quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to `None`, we assume the model weights are not quantized and use `dtype` to
...@@ -338,7 +315,6 @@ class ModelConfig: ...@@ -338,7 +315,6 @@ class ModelConfig:
ignored_factors = { ignored_factors = {
"runner", "runner",
"convert", "convert",
"task",
"tokenizer", "tokenizer",
"tokenizer_mode", "tokenizer_mode",
"seed", "seed",
...@@ -513,97 +489,6 @@ class ModelConfig: ...@@ -513,97 +489,6 @@ class ModelConfig:
is_generative_model = registry.is_text_generation_model(architectures, self) is_generative_model = registry.is_text_generation_model(architectures, self)
is_pooling_model = registry.is_pooling_model(architectures, self) is_pooling_model = registry.is_pooling_model(architectures, self)
def _task_to_convert(task: TaskOption) -> ConvertType:
if task == "embedding" or task == "embed":
return "embed"
if task == "classify":
return "classify"
if task == "reward":
logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score":
new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed"
return "none"
if self.task is not None:
runner: RunnerOption = "auto"
convert: ConvertOption = "auto"
msg_prefix = (
"The 'task' option has been deprecated and will be "
"removed in v0.13.0 or v1.0, whichever comes first."
)
msg_hint = "Please remove this option."
is_generative_task = self.task in _RUNNER_TASKS["generate"]
is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
if is_generative_model and is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"generate` to continue using this model "
"as a generative model."
)
elif is_pooling_task:
runner = "pooling"
convert = "auto"
msg_hint = (
"Please replace this option with `--runner "
"pooling` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
elif is_generative_model or is_pooling_model:
if is_generative_task:
runner = "generate"
convert = "auto"
msg_hint = "Please remove this option"
elif is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--convert "
f"{convert}` to continue using this model "
"as a pooling model."
)
else: # task == "auto"
pass
else:
# Neither generative nor pooling model - try to convert if possible
if is_pooling_task:
runner = "pooling"
convert = _task_to_convert(self.task)
msg_hint = (
"Please replace this option with `--runner pooling "
f"--convert {convert}` to continue using this model "
"as a pooling model."
)
else:
debug_info = {
"architectures": architectures,
"is_generative_model": is_generative_model,
"is_pooling_model": is_pooling_model,
}
raise AssertionError(
"The model should be a generative or "
"pooling model when task is set to "
f"{self.task!r}. Found: {debug_info}"
)
self.runner = runner
self.convert = convert
msg = f"{msg_prefix} {msg_hint}"
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.runner_type = self._get_runner_type(architectures, self.runner) self.runner_type = self._get_runner_type(architectures, self.runner)
self.convert_type = self._get_convert_type( self.convert_type = self._get_convert_type(
architectures, self.runner_type, self.convert architectures, self.runner_type, self.convert
...@@ -657,6 +542,11 @@ class ModelConfig: ...@@ -657,6 +542,11 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
if self.is_encoder_decoder:
self.mm_processor_cache_gb = 0
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
# Init multimodal config if needed # Init multimodal config if needed
if self._model_info.supports_multimodal: if self._model_info.supports_multimodal:
if ( if (
...@@ -710,6 +600,14 @@ class ModelConfig: ...@@ -710,6 +600,14 @@ class ModelConfig:
self._verify_cuda_graph() self._verify_cuda_graph()
self._verify_bnb_config() self._verify_bnb_config()
@field_validator("tokenizer", "max_model_len", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
if value is None:
return value
return handler(value)
@field_validator("tokenizer_mode", mode="after") @field_validator("tokenizer_mode", mode="after")
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
return tokenizer_mode.lower() return tokenizer_mode.lower()
...@@ -723,10 +621,19 @@ class ModelConfig: ...@@ -723,10 +621,19 @@ class ModelConfig:
@model_validator(mode="after") @model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
"""Called after __post_init__"""
if not isinstance(self.tokenizer, str): if not isinstance(self.tokenizer, str):
raise ValueError("tokenizer must be a string after __post_init__.") raise ValueError(
f"tokenizer must be a string, got "
f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
"Please provide a valid tokenizer path or HuggingFace model ID."
)
if not isinstance(self.max_model_len, int): if not isinstance(self.max_model_len, int):
raise ValueError("max_model_len must be an integer after __post_init__.") raise ValueError(
f"max_model_len must be a positive integer, "
f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "
"Example: max_model_len=2048"
)
return self return self
def _get_transformers_backend_cls(self) -> str: def _get_transformers_backend_cls(self) -> str:
...@@ -906,6 +813,13 @@ class ModelConfig: ...@@ -906,6 +813,13 @@ class ModelConfig:
runner_type: RunnerType, runner_type: RunnerType,
convert: ConvertOption, convert: ConvertOption,
) -> ConvertType: ) -> ConvertType:
if convert == "reward":
logger.warning(
"`--convert reward` is deprecated and will be removed in v0.15. "
"Please use `--convert embed` instead."
)
return "embed"
if convert != "auto": if convert != "auto":
return convert return convert
...@@ -921,22 +835,6 @@ class ModelConfig: ...@@ -921,22 +835,6 @@ class ModelConfig:
return convert_type return convert_type
def _get_default_pooling_task(
self,
architectures: list[str],
) -> Literal["embed", "classify", "reward"]:
if self.registry.is_cross_encoder_model(architectures, self):
return "classify"
for arch in architectures:
match = try_match_architecture_defaults(arch, runner_type="pooling")
if match:
_, (_, convert_type) = match
assert convert_type != "none"
return convert_type
return "embed"
def _parse_quant_hf_config(self, hf_config: PretrainedConfig): def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
quant_cfg = getattr(hf_config, "quantization_config", None) quant_cfg = getattr(hf_config, "quantization_config", None)
if quant_cfg is None: if quant_cfg is None:
...@@ -1308,7 +1206,15 @@ class ModelConfig: ...@@ -1308,7 +1206,15 @@ class ModelConfig:
// block.attention.n_heads_in_group // block.attention.n_heads_in_group
) )
raise RuntimeError("Couldn't determine number of kv heads") raise RuntimeError(
"Could not determine the number of key-value attention heads "
"from model configuration. "
f"Model: {self.model}, Architecture: {self.architectures}. "
"This usually indicates an unsupported model architecture or "
"missing configuration. "
"Please check if your model is supported at: "
"https://docs.vllm.ai/en/latest/models/supported_models.html"
)
if self.is_attention_free: if self.is_attention_free:
return 0 return 0
...@@ -1902,6 +1808,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ ...@@ -1902,6 +1808,7 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForTextEncoding", ("pooling", "embed")), ("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")), ("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")), ("ForSequenceClassification", ("pooling", "classify")),
("ForTokenClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")), ("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")), ("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")), ("ForVideoClassification", ("pooling", "classify")),
......
...@@ -317,11 +317,6 @@ class ParallelConfig: ...@@ -317,11 +317,6 @@ class ParallelConfig:
"num_redundant_experts." "num_redundant_experts."
) )
if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self return self
@property @property
......
...@@ -111,13 +111,15 @@ class PoolerConfig: ...@@ -111,13 +111,15 @@ class PoolerConfig:
def get_use_activation(o: object): def get_use_activation(o: object):
if softmax := getattr(o, "softmax", None) is not None: if softmax := getattr(o, "softmax", None) is not None:
logger.warning_once( logger.warning_once(
"softmax will be deprecated, please use use_activation instead." "softmax will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
) )
return softmax return softmax
if activation := getattr(o, "activation", None) is not None: if activation := getattr(o, "activation", None) is not None:
logger.warning_once( logger.warning_once(
"activation will be deprecated, please use use_activation instead." "activation will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
) )
return activation return activation
......
...@@ -122,10 +122,12 @@ class SchedulerConfig: ...@@ -122,10 +122,12 @@ class SchedulerConfig:
the default scheduler. Can be a class directly or the path to a class of the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class".""" form "mod.custom_class"."""
disable_hybrid_kv_cache_manager: bool = False disable_hybrid_kv_cache_manager: bool | None = None
"""If set to True, KV cache manager will allocate the same size of KV cache """If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention. like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
""" """
async_scheduling: bool = False async_scheduling: bool = False
......
...@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
) )
def getattr_iter(
    object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
    """
    Return the first attribute of `object` whose name appears in `names`.

    This is useful when fetching attributes from arbitrary
    `transformers.PretrainedConfig` instances, where the same value may be
    stored under several possible names.

    Args:
        object: The object to inspect.
        names: Candidate attribute names, tried in iteration order. Any
            iterable is accepted (list, tuple, generator, ...); it is
            consumed lazily and only up to the first match.
        default: Value returned when none of the names exist on `object`.
        warn: When True, the first name in `names` is treated as the
            preferred name and every later name as a deprecated alias; a
            warning is logged (once) if the match is a deprecated alias.

    Returns:
        The value of the first matching attribute, or `default`.
    """
    preferred_name: str | None = None
    for i, name in enumerate(names):
        if i == 0:
            # Remember the preferred name while iterating instead of indexing
            # `names[0]` later: `names` is only guaranteed to be an Iterable,
            # so it may not support subscripting (e.g. a generator or a set).
            preferred_name = name
        if hasattr(object, name):
            if warn and i > 0:
                logger.warning_once(
                    "%s contains a deprecated attribute name '%s'. "
                    "Please use the preferred attribute name '%s' instead.",
                    type(object).__name__,
                    name,
                    preferred_name,
                )
            return getattr(object, name)
    return default
......
...@@ -666,8 +666,9 @@ class VllmConfig: ...@@ -666,8 +666,9 @@ class VllmConfig:
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config) self._apply_optimization_level_defaults(default_config)
if ( if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
): ):
logger.info( logger.info(
...@@ -692,22 +693,29 @@ class VllmConfig: ...@@ -692,22 +693,29 @@ class VllmConfig:
if current_platform.support_static_graph_mode(): if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support # if cudagraph_mode has full cudagraphs, we need to check support
if ( if model_config := self.model_config:
self.compilation_config.cudagraph_mode.has_full_cudagraphs() if (
and self.model_config is not None self.compilation_config.cudagraph_mode.has_full_cudagraphs()
): and model_config.pooler_config is not None
if self.model_config.pooler_config is not None: ):
logger.warning_once( logger.warning_once(
"Pooling models do not support full cudagraphs. " "Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE." "Overriding cudagraph_mode to PIECEWISE."
) )
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config.is_encoder_decoder: elif (
logger.warning_once( model_config.is_encoder_decoder
"Encoder-decoder models do not support full cudagraphs. " and self.compilation_config.cudagraph_mode
"Overriding cudagraph_mode to PIECEWISE." not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
) )
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution # disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager: if self.model_config is not None and self.model_config.enforce_eager:
...@@ -742,27 +750,17 @@ class VllmConfig: ...@@ -742,27 +750,17 @@ class VllmConfig:
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges() self._set_compile_ranges()
if self.model_config and self.model_config.is_encoder_decoder: if (
from vllm.multimodal import MULTIMODAL_REGISTRY self.model_config
and self.model_config.architecture == "WhisperForConditionalGeneration"
self.scheduler_config.max_num_encoder_input_tokens = ( and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) ):
) logger.warning(
logger.debug( "Whisper is known to have issues with "
"Encoder-decoder model detected: setting " "forked workers. If startup is hanging, "
"`max_num_encoder_input_tokens` to encoder length (%s)", "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
self.scheduler_config.max_num_encoder_input_tokens, "to 'spawn'."
) )
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if ( if (
self.kv_events_config is not None self.kv_events_config is not None
...@@ -812,11 +810,6 @@ class VllmConfig: ...@@ -812,11 +810,6 @@ class VllmConfig:
f"({self.parallel_config.cp_kv_cache_interleave_size})." f"({self.parallel_config.cp_kv_cache_interleave_size})."
) )
assert (
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode # Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1( self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend, all2all_backend=self.parallel_config.all2all_backend,
...@@ -894,17 +887,48 @@ class VllmConfig: ...@@ -894,17 +887,48 @@ class VllmConfig:
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
if not self.scheduler_config.disable_hybrid_kv_cache_manager: # Hybrid KV cache manager (HMA) runtime rules:
# logger should only print warning message for hybrid models. As we # - Explicit enable (--no-disable-hybrid-kv-cache-manager): error if runtime
# can't know whether the model is hybrid or not now, so we don't log # disables it
# warning message here and will log it later. # - No preference: auto-disable for unsupported features (e.g. kv connector)
if not current_platform.support_hybrid_kv_cache(): # - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it
# Hybrid KV cache manager is not supported on non-GPU platforms. need_disable_hybrid_kv_cache_manager = False
self.scheduler_config.disable_hybrid_kv_cache_manager = True # logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
need_disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
need_disable_hybrid_kv_cache_manager = True
if (
self.model_config is not None
and self.model_config.attention_chunk_size is not None
):
if (
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
# Hybrid KV cache manager is not yet supported with chunked
# local attention + eagle.
need_disable_hybrid_kv_cache_manager = True
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE:
logger.warning(
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
need_disable_hybrid_kv_cache_manager = True
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None: if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector for now. # NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
# TODO(Kuntai): have a more elegent solution to check and need_disable_hybrid_kv_cache_manager = True
# turn off HMA for connector that does not support HMA.
logger.warning( logger.warning(
"Turning off hybrid kv cache manager because " "Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the " "`--kv-transfer-config` is set. This will reduce the "
...@@ -912,33 +936,26 @@ class VllmConfig: ...@@ -912,33 +936,26 @@ class VllmConfig:
"or Mamba attention. If you are a developer of kv connector" "or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for " ", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass" "your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py." " of `SupportsHMA` defined in kv_connector/v1/base.py and"
" use --no-disable-hybrid-kv-cache-manager to start vLLM."
) )
self.scheduler_config.disable_hybrid_kv_cache_manager = True self.scheduler_config.disable_hybrid_kv_cache_manager = (
if self.kv_events_config is not None: need_disable_hybrid_kv_cache_manager
# Hybrid KV cache manager is not compatible with KV events. )
self.scheduler_config.disable_hybrid_kv_cache_manager = True elif (
if ( self.scheduler_config.disable_hybrid_kv_cache_manager is False
self.model_config is not None and need_disable_hybrid_kv_cache_manager
and self.model_config.attention_chunk_size is not None ):
): raise ValueError(
if ( "Hybrid KV cache manager was explicitly enabled but is not "
self.speculative_config is not None "supported in this configuration. Consider omitting the "
and self.speculative_config.use_eagle() "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide"
): " automatically."
# Hybrid KV cache manager is not yet supported with chunked )
# local attention + eagle.
self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: # Default to enable HMA if not explicitly disabled by user or logic above.
logger.warning( self.scheduler_config.disable_hybrid_kv_cache_manager = False
"There is a latency regression when using chunked local"
" attention with the hybrid KV cache manager. Disabling"
" it, by default. To enable it, set the environment "
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
)
# Hybrid KV cache manager is not yet supported with chunked
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.compilation_config.debug_dump_path: if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path = (
...@@ -1006,7 +1023,7 @@ class VllmConfig: ...@@ -1006,7 +1023,7 @@ class VllmConfig:
max_graph_size = min(max_num_seqs * 2, 512) max_graph_size = min(max_num_seqs * 2, 512)
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16 # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
# up to max_graph_size # up to max_graph_size
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
range(256, max_graph_size + 1, 16)) range(256, max_graph_size + 1, 16))
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
import pickle import pickle
import threading
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
...@@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL ...@@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
from_bytes_big = functools.partial(int.from_bytes, byteorder="big") from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
# Lock used purely for its acquire/release side effect; it never guards any
# data. memory_fence() cycles it to order accesses to the shared-memory ring
# buffer used for lock-free producer/consumer synchronization.
_memory_fence_lock = threading.Lock()


def memory_fence():
    """
    Act as a memory barrier by cycling a module-level lock.

    The function acquires `_memory_fence_lock` and releases it right away,
    doing no work in between. The intent is that writes issued before the
    call become visible to other threads/processes, and that reads issued
    after the call observe up-to-date values.

    NOTE(review): the barrier effect relies on how the interpreter's lock
    primitive is implemented on the host platform; it is not something the
    `threading.Lock` API documents. Any cost figure for this call should be
    measured rather than assumed.

    Returns:
        None.
    """
    # Nothing can raise between these two calls, so the lock is always
    # released before returning.
    _memory_fence_lock.acquire()
    _memory_fence_lock.release()
def to_bytes_big(value: int, size: int) -> bytes: def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big") return value.to_bytes(size, byteorder="big")
...@@ -414,6 +442,10 @@ class MessageQueue: ...@@ -414,6 +442,10 @@ class MessageQueue:
n_warning = 1 n_warning = 1
while True: while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer: with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence()
read_count = sum(metadata_buffer[1:]) read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0] written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader: if written_flag and read_count != self.buffer.n_reader:
...@@ -458,6 +490,10 @@ class MessageQueue: ...@@ -458,6 +490,10 @@ class MessageQueue:
metadata_buffer[i] = 0 metadata_buffer[i] = 0
# mark the block as written # mark the block as written
metadata_buffer[0] = 1 metadata_buffer[0] = 1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break break
...@@ -473,6 +509,10 @@ class MessageQueue: ...@@ -473,6 +509,10 @@ class MessageQueue:
n_warning = 1 n_warning = 1
while True: while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer: with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though writer has updated them.
memory_fence()
read_flag = metadata_buffer[self.local_reader_rank + 1] read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0] written_flag = metadata_buffer[0]
if not written_flag or read_flag: if not written_flag or read_flag:
...@@ -513,6 +553,10 @@ class MessageQueue: ...@@ -513,6 +553,10 @@ class MessageQueue:
# caller has read from the buffer # caller has read from the buffer
# set the read flag # set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1 metadata_buffer[self.local_reader_rank + 1] = 1
# Memory fence ensures the read flag is visible to the writer.
# Without this, writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity() self._read_spin_timer.record_activity()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment