Commit 41199996 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.12.0' into v0.12.0-dev

parents 31021d81 4fd9d6a8
......@@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0
MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
```
#### 2. Maximize Throughput with a Latency Requirement
### 2. Maximize Throughput with a Latency Requirement
- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
- **Configuration**:
......@@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0
MAX_LATENCY_ALLOWED_MS=500
```
#### 3. Maximize Throughput with Prefix Caching and Latency Requirements
### 3. Maximize Throughput with Prefix Caching and Latency Requirements
- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
- **Configuration**:
......
......@@ -74,7 +74,7 @@ start_server() {
local vllm_log=$4
local profile_dir=$5
pkill -if vllm
pkill -if "vllm serve" || true
# Define the common arguments as a bash array.
# Each argument and its value are separate elements.
......@@ -96,11 +96,11 @@ start_server() {
# This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else
# Start server without profiling
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
VLLM_SERVER_DEV_MODE=1 \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
fi
local server_pid=$!
......@@ -139,7 +139,7 @@ run_benchmark() {
echo "vllm_log: $vllm_log"
echo
rm -f $vllm_log
pkill -if vllm
pkill -if "vllm serve" || true
echo "starting server..."
# Call start_server without a profile_dir to avoid profiling overhead
......@@ -232,7 +232,7 @@ run_benchmark() {
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
pkill -if vllm
pkill -if "vllm serve" || true
sleep 10
echo "===================="
return 0
......@@ -308,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
else
echo "No configuration met the latency requirements. Skipping final profiling run."
fi
pkill -if vllm
pkill -if "vllm serve" || true
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
......@@ -8,7 +8,6 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional, Union
import aiohttp
import huggingface_hub.constants
......@@ -28,13 +27,13 @@ class RequestFuncInput:
prompt_len: int
output_len: int
model: str
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict | list[dict]] = None
model_name: str | None = None
logprobs: int | None = None
extra_body: dict | None = None
multi_modal_content: dict | list[dict] | None = None
ignore_eos: bool = False
language: Optional[str] = None
request_id: Optional[str] = None
language: str | None = None
request_id: str | None = None
@dataclass
......@@ -52,7 +51,7 @@ class RequestFuncOutput:
async def async_request_tgi(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
......@@ -133,7 +132,7 @@ async def async_request_tgi(
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
......@@ -204,7 +203,7 @@ async def async_request_trt_llm(
async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("completions", "profile")), (
......@@ -267,7 +266,7 @@ async def async_request_deepspeed_mii(
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("completions", "profile")), (
......@@ -367,7 +366,7 @@ async def async_request_openai_completions(
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), (
......@@ -476,7 +475,7 @@ async def async_request_openai_chat_completions(
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
......@@ -610,7 +609,7 @@ def get_tokenizer(
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path
):
......@@ -621,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
except ImportError as e:
raise ImportError(
"MistralTokenizer requires vllm package.\n"
......
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
This benchmark runs the same workload twice:
1. With VLLM_BATCH_INVARIANT=0 (baseline)
2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
And reports the timing and throughput metrics for comparison.
Environment variables:
VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024)
VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048)
VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
Example usage:
# Benchmark qwen3 (default)
python benchmarks/benchmark_batch_invariance.py
# Benchmark deepseek with 8 GPUs
VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\
python benchmarks/benchmark_batch_invariance.py
# Quick test with fewer trials
VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\
python benchmarks/benchmark_batch_invariance.py
"""
import contextlib
import os
import random
import time
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
"""Generate a random prompt for benchmarking."""
prompt_templates = [
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
def run_benchmark_with_batch_invariant(
model: str,
tp_size: int,
max_batch_size: int,
num_trials: int,
min_prompt: int,
max_prompt: int,
max_tokens: int,
temperature: float,
gpu_mem_util: float,
max_model_len: int,
backend: str,
batch_invariant: bool,
seed: int = 12345,
) -> dict:
"""
Run the benchmark with the specified configuration.
Returns a dict with timing and throughput metrics.
"""
random.seed(seed)
# Set environment variables
os.environ["VLLM_ATTENTION_BACKEND"] = backend
if batch_invariant:
os.environ["VLLM_BATCH_INVARIANT"] = "1"
else:
os.environ["VLLM_BATCH_INVARIANT"] = "0"
print(f"\n{'=' * 80}")
print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
print(f" Model: {model}")
print(f" TP Size: {tp_size}")
print(f" Backend: {backend}")
print(f" Max Batch Size: {max_batch_size}")
print(f" Trials: {num_trials}")
print(f" Max Tokens: {max_tokens}")
print(f"{'=' * 80}\n")
sampling = SamplingParams(
temperature=temperature,
top_p=0.95,
max_tokens=max_tokens,
seed=20240919,
)
needle_prompt = "There once was a "
llm = None
try:
# Create LLM engine
start_init = time.perf_counter()
llm = LLM(
model=model,
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
)
init_time = time.perf_counter() - start_init
print(f"Engine initialization time: {init_time:.2f}s\n")
# Generate baseline
print("Generating baseline (warmup)...")
baseline_out = llm.generate([needle_prompt], sampling)
assert len(baseline_out) == 1
baseline_text = baseline_out[0].outputs[0].text
print(f"Baseline output: '{baseline_text[:50]}...'\n")
# Run trials and measure timing
trial_times: list[float] = []
total_tokens = 0
total_prompts = 0
for trial in range(num_trials):
# Create a batch
prompts: list[str] = []
batch_size = random.randint(max_batch_size // 2, max_batch_size)
needle_pos = random.randint(0, batch_size - 1)
for i in range(batch_size):
if i == needle_pos:
prompts.append(needle_prompt)
else:
prompts.append(_random_prompt(min_prompt, max_prompt))
# Measure time for this trial
start_time = time.perf_counter()
outputs = llm.generate(prompts, sampling)
trial_time = time.perf_counter() - start_time
trial_times.append(trial_time)
total_prompts += len(prompts)
# Count tokens
for output in outputs:
if output.outputs:
total_tokens += len(output.outputs[0].token_ids)
print(
f"Trial {trial + 1}/{num_trials}: "
f"batch_size={batch_size}, "
f"time={trial_time:.2f}s"
)
# Verify needle output still matches
needle_output = outputs[needle_pos]
assert needle_output.prompt == needle_prompt
# Compute statistics
avg_time = sum(trial_times) / len(trial_times)
min_time = min(trial_times)
max_time = max(trial_times)
throughput = total_tokens / sum(trial_times)
prompts_per_sec = total_prompts / sum(trial_times)
print(f"\n{'=' * 80}")
print("RESULTS:")
print(f" Average time per trial: {avg_time:.2f}s")
print(f" Min time: {min_time:.2f}s")
print(f" Max time: {max_time:.2f}s")
print(f" Total tokens generated: {total_tokens}")
print(f" Total prompts processed: {total_prompts}")
print(f" Throughput: {throughput:.2f} tokens/s")
print(f" Prompts/s: {prompts_per_sec:.2f}")
print(f"{'=' * 80}\n")
return {
"init_time": init_time,
"avg_time": avg_time,
"min_time": min_time,
"max_time": max_time,
"total_tokens": total_tokens,
"total_prompts": total_prompts,
"throughput": throughput,
"prompts_per_sec": prompts_per_sec,
"trial_times": trial_times,
}
finally:
# Cleanup
if llm is not None:
with contextlib.suppress(Exception):
llm.shutdown()
def main():
# Check platform support
if not (current_platform.is_cuda() and current_platform.has_device_capability(90)):
print("ERROR: Requires CUDA and >= Hopper (SM90)")
print(f"Current platform: {current_platform.device_type}")
if current_platform.is_cuda():
print(f"Device capability: {current_platform.get_device_capability()}")
return 1
# Read configuration from environment
model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1"))
max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128"))
num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5"))
min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024"))
max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048"))
max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128"))
temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0"))
gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN")
print("\n" + "=" * 80)
print("VLLM BATCH INVARIANCE BENCHMARK")
print("=" * 80)
print("\nConfiguration:")
print(f" Model: {model}")
print(f" Tensor Parallel Size: {tp_size}")
print(f" Attention Backend: {backend}")
print(f" Max Batch Size: {max_batch_size}")
print(f" Number of Trials: {num_trials}")
print(f" Prompt Length Range: {min_prompt}-{max_prompt} words")
print(f" Max Tokens to Generate: {max_tokens}")
print(f" Temperature: {temperature}")
print(f" GPU Memory Utilization: {gpu_mem_util}")
print(f" Max Model Length: {max_model_len}")
print("=" * 80)
# Run benchmark WITHOUT batch invariance (baseline)
print("\n" + "=" * 80)
print("PHASE 1: Running WITHOUT batch invariance (baseline)")
print("=" * 80)
baseline_results = run_benchmark_with_batch_invariant(
model=model,
tp_size=tp_size,
max_batch_size=max_batch_size,
num_trials=num_trials,
min_prompt=min_prompt,
max_prompt=max_prompt,
max_tokens=max_tokens,
temperature=temperature,
gpu_mem_util=gpu_mem_util,
max_model_len=max_model_len,
backend=backend,
batch_invariant=False,
)
# Run benchmark WITH batch invariance
print("\n" + "=" * 80)
print("PHASE 2: Running WITH batch invariance")
print("=" * 80)
batch_inv_results = run_benchmark_with_batch_invariant(
model=model,
tp_size=tp_size,
max_batch_size=max_batch_size,
num_trials=num_trials,
min_prompt=min_prompt,
max_prompt=max_prompt,
max_tokens=max_tokens,
temperature=temperature,
gpu_mem_util=gpu_mem_util,
max_model_len=max_model_len,
backend=backend,
batch_invariant=True,
)
# Compare results
print("\n" + "=" * 80)
print("COMPARISON: Batch Invariance vs Baseline")
print("=" * 80)
init_overhead_pct = (
(batch_inv_results["init_time"] - baseline_results["init_time"])
/ baseline_results["init_time"]
* 100
)
time_overhead_pct = (
(batch_inv_results["avg_time"] - baseline_results["avg_time"])
/ baseline_results["avg_time"]
* 100
)
throughput_change_pct = (
(batch_inv_results["throughput"] - baseline_results["throughput"])
/ baseline_results["throughput"]
* 100
)
print("\nInitialization Time:")
print(f" Baseline: {baseline_results['init_time']:.2f}s")
print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s")
print(f" Overhead: {init_overhead_pct:+.2f}%")
print("\nAverage Trial Time:")
print(f" Baseline: {baseline_results['avg_time']:.2f}s")
print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
print(f" Overhead: {time_overhead_pct:+.2f}%")
print("\nThroughput (tokens/s):")
print(f" Baseline: {baseline_results['throughput']:.2f}")
print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}")
print(f" Change: {throughput_change_pct:+.2f}%")
print("\nPrompts/s:")
print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}")
print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
if time_overhead_pct > 0:
print(
f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
"overhead"
)
else:
print(
f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
"faster (unexpected!)"
)
if abs(throughput_change_pct) < 1.0:
print("Throughput difference is negligible (< 1%)")
elif throughput_change_pct < 0:
print(
f"Throughput decreased by {-throughput_change_pct:.1f}% "
"with batch invariance"
)
else:
print(
f"Throughput increased by {throughput_change_pct:.1f}% "
"with batch invariance (unexpected!)"
)
print("=" * 80 + "\n")
return 0
if __name__ == "__main__":
exit(main())
......@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from benchmark_utils import TimeCollector
from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
......
......@@ -46,7 +46,7 @@ import time
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
......
......@@ -5,9 +5,9 @@ import time
from unittest import mock
import numpy as np
from benchmark_utils import TimeCollector
from tabulate import tabulate
from benchmark_utils import TimeCollector
from vllm.config import (
CacheConfig,
DeviceConfig,
......@@ -19,7 +19,7 @@ from vllm.config import (
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
......@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(),
scheduler_config=SchedulerConfig(
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
)
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
......@@ -164,7 +167,7 @@ def invoke_main() -> None:
)
parser.add_argument(
"--batched", action="store_true", help="consider time to prepare batch"
) # noqa: E501
)
parser.add_argument(
"--num-iteration",
type=int,
......
......@@ -32,18 +32,15 @@ import dataclasses
import json
import random
import time
from typing import Optional
from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
# import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.tokenizers import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
......@@ -85,7 +82,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
# Remove the special tokens.
return random.choices(
[v for k, v in vocab.items() if k not in all_special_ids],
[v for v in vocab.values() if v not in all_special_ids],
k=length,
)
......@@ -95,7 +92,7 @@ def sample_requests_from_dataset(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int],
fixed_output_len: Optional[int],
fixed_output_len: int | None,
) -> list[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
......@@ -143,7 +140,7 @@ def sample_requests_from_random(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int],
fixed_output_len: Optional[int],
fixed_output_len: int | None,
prefix_len: int,
) -> list[Request]:
requests = []
......
......@@ -7,12 +7,11 @@ import dataclasses
import json
import random
import time
from typing import Optional
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Select a equi-probable random priority
......@@ -24,7 +23,7 @@ def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
fixed_output_len: int | None,
) -> list[tuple[str, int, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
......
......@@ -31,28 +31,27 @@ import time
import uuid
import warnings
from collections.abc import AsyncGenerator
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional
import datasets
import numpy as np
import pandas as pd
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.tokenizers import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
try:
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
......@@ -317,7 +316,7 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: list[str],
selected_percentiles: list[float],
goodput_config_dict: Optional[dict[str, float]] = None,
goodput_config_dict: dict[str, float] | None = None,
) -> tuple[BenchmarkMetrics, list[int]]:
actual_output_lens: list[int] = []
total_input = 0
......@@ -437,9 +436,9 @@ async def benchmark(
selected_percentile_metrics: list[str],
selected_percentiles: list[str],
ignore_eos: bool,
max_concurrency: Optional[int],
max_concurrency: int | None,
structured_output_ratio: float,
goodput_config_dict: Optional[dict[str, float]] = None,
goodput_config_dict: dict[str, float] | None = None,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
......@@ -503,15 +502,9 @@ async def benchmark(
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
# This can be used once the minimum Python version is 3.10 or higher,
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext()
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input, pbar=pbar)
......@@ -910,13 +903,13 @@ def create_argument_parser():
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument(
"--num-prompts",
......
......@@ -6,7 +6,7 @@ import math
import os
import time
from types import TracebackType
from typing import Any, Optional, Union
from typing import Any
def convert_to_pytorch_benchmark_format(
......@@ -92,7 +92,7 @@ class TimeCollector:
def __init__(self, scale: int) -> None:
self.cnt: int = 0
self._sum: int = 0
self._max: Optional[int] = None
self._max: int | None = None
self.scale = scale
self.start_time: int = time.monotonic_ns()
......@@ -104,13 +104,13 @@ class TimeCollector:
else:
self._max = max(self._max, v)
def avg(self) -> Union[float, str]:
def avg(self) -> float | str:
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
def max(self) -> Union[float, str]:
def max(self) -> float | str:
return self._max / self.scale if self._max else "N/A"
def dump_avg_max(self) -> list[Union[float, str]]:
def dump_avg_max(self) -> list[float | str]:
return [self.avg(), self.max()]
def __enter__(self) -> None:
......@@ -118,8 +118,8 @@ class TimeCollector:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.collect(time.monotonic_ns() - self.start_time)
......@@ -6,8 +6,7 @@ import copy
import itertools
import pickle as pkl
import time
from collections.abc import Iterable
from typing import Callable
from collections.abc import Callable, Iterable
import torch
import torch.utils.benchmark as TBenchmark
......@@ -16,7 +15,7 @@ from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
......
......@@ -6,8 +6,7 @@ import copy
import itertools
import pickle as pkl
import time
from collections.abc import Iterable
from typing import Callable, Optional
from collections.abc import Callable, Iterable
import torch
import torch.utils.benchmark as TBenchmark
......@@ -17,9 +16,10 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul,
w8a8_triton_block_scaled_mm,
)
from vllm.utils import FlexibleArgumentParser, cdiv
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import cdiv
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
......@@ -53,7 +53,7 @@ def bench_int8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
......@@ -108,7 +108,7 @@ def bench_fp8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
......@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
......@@ -183,7 +183,7 @@ def bench(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[list[str]] = None,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
......@@ -201,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(
dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
......
......@@ -55,9 +55,7 @@ benchmark() {
output_len=$2
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
......@@ -65,9 +63,7 @@ benchmark() {
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
......
......@@ -38,16 +38,12 @@ wait_for_server() {
launch_chunked_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--enable-chunked-prefill \
......@@ -62,18 +58,14 @@ launch_chunked_prefill() {
launch_disagg_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
--model $model \
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
......
......@@ -5,11 +5,12 @@ import argparse
import asyncio
import logging
import os
import time
import uuid
from urllib.parse import urlparse
import aiohttp
from quart import Quart, Response, make_response, request
from rate_limiter import RateLimiter
from request_queue import RequestQueue
# Configure logging
logging.basicConfig(level=logging.INFO)
......@@ -24,26 +25,8 @@ def parse_args():
parser.add_argument(
"--timeout",
type=float,
default=300,
help="Timeout for backend service requests in seconds (default: 300)",
)
parser.add_argument(
"--max-concurrent",
type=int,
default=100,
help="Maximum concurrent requests to backend services (default: 100)",
)
parser.add_argument(
"--queue-size",
type=int,
default=500,
help="Maximum number of requests in the queue (default: 500)",
)
parser.add_argument(
"--rate-limit",
type=int,
default=40,
help="Maximum requests per second (default: 40)",
default=6 * 60 * 60,
help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--port",
......@@ -54,14 +37,32 @@ def parse_args():
parser.add_argument(
"--prefill-url",
type=str,
default="http://localhost:8100/v1/completions",
help="Prefill service endpoint URL",
default="http://localhost:8100",
help="Prefill service base URL (protocol + host[:port])",
)
parser.add_argument(
"--decode-url",
type=str,
default="http://localhost:8200/v1/completions",
help="Decode service endpoint URL",
default="http://localhost:8200",
help="Decode service base URL (protocol + host[:port])",
)
parser.add_argument(
"--kv-host",
type=str,
default="localhost",
help="Hostname or IP used by KV transfer (default: localhost)",
)
parser.add_argument(
"--prefill-kv-port",
type=int,
default=14579,
help="Prefill KV port (default: 14579)",
)
parser.add_argument(
"--decode-kv-port",
type=int,
default=14580,
help="Decode KV port (default: 14580)",
)
return parser.parse_args()
......@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
MAX_CONCURRENT_REQUESTS = args.max_concurrent
REQUEST_QUEUE_SIZE = args.queue_size
RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url
PORT = args.port
app = Quart(__name__)
PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
# Initialize the rate limiter and request queue
rate_limiter = RateLimiter(RATE_LIMIT)
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
logger.info(
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
PREFILL_KV_ADDR,
DECODE_KV_ADDR,
)
app = Quart(__name__)
# Attach the configuration object to the application instance
# Attach the configuration object to the application instance so helper
# coroutines can read the resolved backend URLs and timeouts without using
# globals.
app.config.update(
{
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"rate_limiter": rate_limiter,
"request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
"DECODE_KV_ADDR": DECODE_KV_ADDR,
}
)
# Start queue processing on app startup
@app.before_serving
async def startup():
"""Start request processing task when app starts serving"""
asyncio.create_task(request_queue.process())
async def forward_request(url, data):
"""Forward request to backend service with rate limiting and error handling"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
# Use rate limiter as context manager
async with (
rate_limiter,
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
):
try:
async with session.post(
url=url, json=data, headers=headers
) as response:
if response.status == 200:
# Stream response chunks
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
# Handle backend service errors
error_text = await response.text()
logger.error(
"Backend service error: %s - %s",
response.status,
error_text,
)
yield b'{"error": "Backend service error"}'
except aiohttp.ClientError as e:
# Handle connection errors
logger.error("Connection error to %s: %s", url, str(e))
yield b'{"error": "Service unavailable"}'
except asyncio.TimeoutError:
# Handle timeout errors
logger.error("Timeout connecting to %s", url)
yield b'{"error": "Service timeout"}'
def _normalize_base_url(url: str) -> str:
"""Remove any trailing slash so path joins behave predictably."""
return url.rstrip("/")
def _get_host_port(url: str) -> str:
"""Return the hostname:port portion for logging and KV headers."""
parsed = urlparse(url)
host = parsed.hostname or "localhost"
port = parsed.port
if port is None:
port = 80 if parsed.scheme == "http" else 443
return f"{host}:{port}"
PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
def _build_headers(request_id: str) -> dict[str, str]:
"""Construct the headers expected by vLLM's P2P disagg connector."""
headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
async def _run_prefill(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{PREFILL_BASE}{request_path}"
start_ts = time.perf_counter()
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
raise RuntimeError(
f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
async def _stream_decode(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{DECODE_BASE}{request_path}"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger.info("[decode] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
logger.error(
"Decode backend error %s - %s", resp.status, error_text
)
err_msg = (
'{"error": "Decode backend error ' + str(resp.status) + '"}'
)
yield err_msg.encode()
return
logger.info(
"[decode] streaming response request_id=%s status=%s",
request_id,
resp.status,
)
async for chunk_bytes in resp.content.iter_chunked(1024):
yield chunk_bytes
logger.info("[decode] finished streaming request_id=%s", request_id)
except asyncio.TimeoutError:
logger.error("Decode service timeout at %s", url)
yield b'{"error": "Decode service timeout"}'
except aiohttp.ClientError as exc:
logger.error("Decode service error at %s: %s", url, exc)
yield b'{"error": "Decode service unavailable"}'
async def process_request():
"""Process a single request through prefill and decode stages"""
......@@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
# Execute prefill stage
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
continue
# The request id encodes both KV socket addresses so the backend can
# shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response
generator = forward_request(DECODE_SERVICE_URL, original_request_data)
# Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
)
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
return response
......@@ -168,23 +242,10 @@ def main():
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task = asyncio.create_task(process_request())
# Enqueue request or reject if queue is full
if not await request_queue.enqueue(task):
return Response(
response=b'{"error": "Server busy, try again later"}',
status=503,
content_type="application/json",
)
try:
# Return the response from the processing task
return await task
return await process_request()
except asyncio.CancelledError:
# Handle task cancellation (timeout or queue full)
logger.warning("Request cancelled due to timeout or queue full")
logger.warning("Request cancelled")
return Response(
response=b'{"error": "Request cancelled"}',
status=503,
......
......@@ -3,10 +3,9 @@
import pickle as pkl
import time
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
from typing import Callable, Optional
import torch
import torch.utils.benchmark as TBenchmark
......@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
def unfused_int8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor],
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
):
# Norm
......@@ -68,7 +67,7 @@ def unfused_int8_impl(
def unfused_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor],
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
):
# Norm
......@@ -85,7 +84,7 @@ def unfused_fp8_impl(
def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
# Disable DeepGEMM for this benchmark to use CUTLASS
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear,
W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
......@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
# Create random FP8 tensors
# Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
# Create quantized weight tensor
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# Create scales
# Create weight scales
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
......@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
* factor_for_scale
)
# SM90 CUTLASS requires row-major format for scales
if use_cutlass and current_platform.is_device_capability(90):
Bs = Bs.T.contiguous()
# Create W8A8BlockFp8LinearOp instance
weight_group_shape = GroupShape(block_n, block_k)
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=use_cutlass,
use_aiter_and_is_supported=False,
)
def run():
if use_cutlass:
return apply_w8a8_block_fp8_linear(
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
)
else:
return apply_w8a8_block_fp8_linear(
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
)
return linear_op.apply(
input=A_ref,
weight=B,
weight_scale=Bs,
input_scale=None,
bias=None,
)
return run
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"mxfp4": dict(no_a_quant=False, enabled=True),
"mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_mxfp4(
b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
b, forward_hadamard_matrix, method="abs_max"
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
return weight_hf_e2m1, weight_hf_scale_block
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
b, forward_hadamard_matrix, device
)
alpha = torch.tensor([1.0], device="cuda")
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
def run():
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs MXFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_mxfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_mxfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"nvfp4": dict(no_a_quant=False, enabled=True),
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_nvfp4(
b: torch.Tensor,
forward_hadamard_matrix: torch.Tensor,
global_scale: torch.Tensor,
device: str,
M: int,
N: int,
K: int,
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
b, forward_hadamard_matrix, global_scale
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
-1, K // 16
)
return weight_hf_e2m1, weight_hf_scale_block
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
alpha = torch.tensor([1.0], device="cuda")
global_scale = torch.tensor([1.0], device="cuda")
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
b, forward_hadamard_matrix, global_scale, device, M, N, K
)
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
def run():
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs NVFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_nvfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [16, 32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_nvfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment