Commit 41199996 authored by zhuwenwen
Browse files

Merge tag 'v0.12.0' into v0.12.0-dev

parents 31021d81 4fd9d6a8
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
import time import time
...@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( ...@@ -14,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import ( from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash, create_kv_caches_with_random_flash,
) )
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from typing import Optional, Union
import torch import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
...@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module): ...@@ -21,8 +20,8 @@ class HuggingFaceRMSNorm(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
...@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module): ...@@ -41,7 +40,7 @@ class HuggingFaceRMSNorm(nn.Module):
def rmsnorm_naive( def rmsnorm_naive(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
...@@ -65,7 +64,7 @@ def rmsnorm_naive( ...@@ -65,7 +64,7 @@ def rmsnorm_naive(
def rmsnorm_flashinfer( def rmsnorm_flashinfer(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape
...@@ -89,7 +88,7 @@ def rmsnorm_flashinfer( ...@@ -89,7 +88,7 @@ def rmsnorm_flashinfer(
def rmsnorm_vllm( def rmsnorm_vllm(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: torch.Tensor | None = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import accumulate import itertools
from typing import Optional
import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
def benchmark_rope_kernels_multi_lora(
is_neox_style: bool, def get_benchmark(head_size, rotary_dim, is_neox_style, device):
batch_size: int, @triton.testing.perf_report(
seq_len: int, triton.testing.Benchmark(
num_heads: int, x_names=["batch_size", "seq_len", "num_heads"],
head_size: int, x_vals=[list(_) for _ in configs],
rotary_dim: Optional[int], line_arg="provider",
dtype: torch.dtype, line_vals=["torch", "flashinfer", "vllm"],
seed: int, line_names=["PyTorch", "FlashInfer", "vLLM"],
device: str, styles=[("blue", "-"), ("green", "-"), ("red", "-")],
max_position: int = 8192, ylabel="us",
base: float = 10000, plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
) -> None: args={},
current_platform.seed_everything(seed) )
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
# simulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
) )
# non-batched RoPE takes only one scaling factor, we create multiple def benchmark(batch_size, seq_len, num_heads, provider):
# instances to simulate the same behavior dtype = torch.bfloat16
non_batched_ropes: list[RotaryEmbedding] = [] max_position = 8192
for scaling_factor in scaling_factors: base = 10000
non_batched_ropes.append( rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
get_rope( rope = rope.to(dtype=dtype, device=device)
head_size, cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
rotary_dim,
max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
base, query = torch.randn(
is_neox_style, (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
{"rope_type": "linear", "factor": (scaling_factor,)},
)
) )
key = torch.randn_like(query)
positions = torch.randint(0, max_position, (batch_size, seq_len)) quantiles = [0.5, 0.2, 0.8]
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache if provider == "torch":
# together and each query needs to find the right kv cache of its type ms, min_ms, max_ms = triton.testing.do_bench(
offset_map = torch.tensor( lambda: rope.forward_native(positions, query.clone(), key.clone()),
list( quantiles=quantiles,
accumulate(
[0]
+ [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
]
) )
) elif provider == "flashinfer":
) ms, min_ms, max_ms = triton.testing.do_bench(
query_types = torch.randint( lambda: torch.ops.vllm.flashinfer_rotary_embedding(
0, len(scaling_factors), (batch_size, seq_len), device=device positions,
) query.clone(),
# map query types to offsets key.clone(),
query_offsets = offset_map[query_types] head_size,
# the kernel takes flattened offsets cos_sin_cache,
flatten_offsets = query_offsets.flatten() is_neox_style,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# batched queries of the same type together for non-batched RoPE return benchmark
queries = [query[query_types == i] for i in range(len(scaling_factors))]
keys = [key[query_types == i] for i in range(len(scaling_factors))]
packed_qkr = zip(queries, keys, non_batched_ropes)
# synchronize before start timing
torch.cuda.synchronize()
with nvtx.annotate("non-batched", color="yellow"):
for q, k, r in packed_qkr:
r.forward(positions, q, k)
torch.cuda.synchronize()
with nvtx.annotate("batched", color="green"):
batched_rope.forward(positions, query, key, flatten_offsets)
torch.cuda.synchronize()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -117,17 +95,12 @@ if __name__ == "__main__": ...@@ -117,17 +95,12 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
) )
parser.add_argument("--save-path", type=str, default="./configs/rope/")
args = parser.parse_args() args = parser.parse_args()
print(args)
benchmark_rope_kernels_multi_lora( # Get the benchmark function
is_neox_style=args.is_neox_style, benchmark = get_benchmark(
batch_size=args.batch_size, args.head_size, args.rotary_dim, args.is_neox_style, args.device
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rotary_dim,
dtype=getattr(torch, args.dtype),
seed=args.seed,
device=args.device,
) )
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
...@@ -78,11 +78,11 @@ WEIGHT_SHAPES = { ...@@ -78,11 +78,11 @@ WEIGHT_SHAPES = {
} }
WEIGHT_SHAPES_MOE = { WEIGHT_SHAPES_MOE = {
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [ "mistralai/Mixtral-8x7B-Instruct-v0.1": [
[8, 2, 4096, 28672], [8, 2, 4096, 28672],
[8, 2, 14336, 4096], [8, 2, 14336, 4096],
], ],
"nm-testing/deepseekv2-lite": [ "deepseek-ai/DeepSeek-V2-Lite": [
[64, 6, 2048, 1408], [64, 6, 2048, 1408],
], ],
"ibm-granite/granite-3.0-1b-a400m": [ "ibm-granite/granite-3.0-1b-a400m": [
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite
This benchmark compares two SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation
The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""
from collections.abc import Callable from collections.abc import Callable
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -7,7 +21,7 @@ import numpy as np ...@@ -7,7 +21,7 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda, persistent_masked_m_silu_mul_quant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( ...@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens, num_parallel_tokens,
group_size: int = 128, group_size: int = 128,
eps: float = 1e-10, eps: float = 1e-10,
expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
...@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( ...@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
# Parse generation strategies # Parse generation strategies
strategies = ["uniform", "max_t", "first_t"] strategies = ["random_imbalanced", "uniform", "max_t"]
def benchmark( def benchmark(
...@@ -195,15 +210,27 @@ def benchmark( ...@@ -195,15 +210,27 @@ def benchmark(
current_platform.seed_everything(42 + seed_offset) current_platform.seed_everything(42 + seed_offset)
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
if gen_strategy == "uniform": if gen_strategy == "random_imbalanced":
r = torch.rand(size=(E,), device="cuda")
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
mean = total_tokens // n_e
min_max = mean // ratio
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
e[0] = min_max
r = torch.rand(size=(E - 1,))
r /= r.sum()
r *= total_tokens - min_max
r = r.round().long()
e[1:] = r.to(device=device)
return e
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
elif gen_strategy == "uniform":
r = torch.rand(size=(E,))
r /= r.sum() r /= r.sum()
r *= total_tokens r *= total_tokens
tokens_per_expert = r.int() r = r.round().long()
tokens_per_expert = torch.minimum( tokens_per_expert = r
tokens_per_expert,
torch.ones((E,), device=r.device, dtype=torch.int) * T,
)
elif gen_strategy == "max_t": elif gen_strategy == "max_t":
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
tokens_per_expert.fill_(total_tokens / E) tokens_per_expert.fill_(total_tokens / E)
...@@ -226,8 +253,8 @@ def benchmark( ...@@ -226,8 +253,8 @@ def benchmark(
) )
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
# Benchmark # Benchmark
latencies: list[float] = [] latencies: list[float] = []
...@@ -281,40 +308,34 @@ def benchmark( ...@@ -281,40 +308,34 @@ def benchmark(
def create_comparison_plot( def create_comparison_plot(
ratio, cuda_times, baseline_times, config_labels, strategy_name, id ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
): ):
"""Create a comparison plot for a specific generation strategy""" fig, ax = plt.subplots(1, 1, figsize=(18, 6))
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.35 width = 0.25
# Execution Time plot (lower is better) # Execution Time plot (lower is better)
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
ax.bar( ax.bar(
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
)
ax.bar(
x + width / 2,
baseline_times,
width,
label="Baseline",
alpha=0.8,
color="orange",
) )
# Add speedup labels over each bar pair # Add speedup labels over each bar trio
for i in range(len(x)): for i in range(len(x)):
speedup = ratio[i] triton_v2_speedup = ratios[i][1] # triton/v2
max_height = max(cuda_times[i], baseline_times[i]) max_height = max(silu_v2_times[i], triton_times[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i], x[i] + width / 2,
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{speedup:.2f}x", f"{triton_v2_speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=9, fontsize=8,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
...@@ -332,56 +353,75 @@ def create_comparison_plot( ...@@ -332,56 +353,75 @@ def create_comparison_plot(
def create_combined_plot(all_results): def create_combined_plot(all_results):
"""Create a combined plot with all strategies in one PNG"""
num_strategies = len(all_results) num_strategies = len(all_results)
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))
if num_strategies == 1: if num_strategies == 1:
axes = [axes] axes = [axes]
for idx, ( for idx, (
strategy_name, strategy_name,
ratio, all_ratios,
cuda_times, all_silu_v2_results,
baseline_times, all_triton_results,
config_labels, config_labels,
config_x_axis,
) in enumerate(all_results): ) in enumerate(all_results):
ax = axes[idx] ax = axes[idx]
# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths = []
triton_bandwidths = []
flat_ratios = []
for config_results in all_silu_v2_results:
for result in config_results:
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
for config_results in all_triton_results:
for result in config_results:
triton_bandwidths.append(result[3]) # bandwidth percentage
for config_ratios in all_ratios:
for ratio in config_ratios:
flat_ratios.append(ratio)
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.35 width = 0.25
# Execution Time plot (lower is better) # Bandwidth utilization plot (higher is better)
ax.bar( ax.bar(
x - width / 2, x,
cuda_times, silu_v2_bandwidths,
width, width,
label="CUDA Kernel", label="SiLU V2 (CUDA)",
alpha=0.8, alpha=0.8,
color="blue", color="blue",
) )
ax.bar( ax.bar(
x + width / 2, x + width,
baseline_times, triton_bandwidths,
width, width,
label="Baseline", label="Triton Kernel",
alpha=0.8, alpha=0.8,
color="orange", color="green",
) )
# Add speedup labels over each bar pair # Add speedup labels over each bar trio
for i in range(len(x)): for i in range(len(x)):
speedup = ratio[i] triton_v2_speedup = flat_ratios[i] # triton/v2
max_height = max(cuda_times[i], baseline_times[i]) max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i], x[i] + width / 2,
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{speedup:.2f}x", f"{triton_v2_speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=9, fontsize=8,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
...@@ -395,7 +435,7 @@ def create_combined_plot(all_results): ...@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
plt.tight_layout() plt.tight_layout()
filename = "../../silu_bench/silu_benchmark_combined.png" filename = "silu_benchmark_combined_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
...@@ -405,7 +445,9 @@ def create_combined_plot(all_results): ...@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
outer_dim = 7168 outer_dim = 7168
configs = [ configs = [
# DeepSeekV3 Configs # DeepSeekV3 Configs
# (1, 56, 7168),
(8, 1024, 7168), (8, 1024, 7168),
# (32, 56, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
(32, 1024, 7168), (32, 1024, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
...@@ -417,6 +459,7 @@ num_warmups = 20 ...@@ -417,6 +459,7 @@ num_warmups = 20
strategy_descriptions = { strategy_descriptions = {
"uniform": "Uniform Random", "uniform": "Uniform Random",
"random_imbalanced": "Imbalanced Random",
"max_t": "Even Assignment", "max_t": "Even Assignment",
"first_t": "experts[0] = T, experts[1:] = 0", "first_t": "experts[0] = T, experts[1:] = 0",
} }
...@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies): ...@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
print(f"Testing strategy: {strategy_descriptions[strategy]}") print(f"Testing strategy: {strategy_descriptions[strategy]}")
print(f"{'=' * 60}") print(f"{'=' * 60}")
# Collect benchmark data for both algorithms # Collect benchmark data for all three algorithms
config_labels = [] config_labels = []
config_x_axis = [] config_x_axis = []
all_cuda_results = [] all_silu_v2_results = []
all_baseline_results = [] all_triton_results = []
all_ratios = [] all_ratios = []
for E, T, H in configs: for E, T, H in configs:
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] total_tokens_config = []
for i in [8, 16, 32, 64, 128, 256, 512]:
if i <= T:
total_tokens_config.append(i * E)
config_x_axis.append(total_tokens_config) config_x_axis.append(total_tokens_config)
cuda_results = [] silu_v2_results = []
baseline_results = [] triton_results = []
ratios = [] ratios = []
for total_tokens in total_tokens_config: for total_tokens in total_tokens_config:
config_label = f"E={E},T={T},H={H},TT={total_tokens}" config_label = f"E={E},T={T},H={H},TT={total_tokens}"
config_labels.append(config_label) config_labels.append(config_label)
# CUDA kernel results # SiLU V2 (CUDA kernel) results
time_ms_cuda, gflops, gbps, perc = benchmark( time_ms_silu_v2, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_cuda, persistent_masked_m_silu_mul_quant,
E, E,
T, T,
H, H,
...@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies): ...@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
cuda_results.append((time_ms_cuda, gflops, gbps, perc)) silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))
# Baseline results # Triton kernel results
time_ms_triton, gflops, gbps, perc = benchmark( time_ms_triton, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_triton, silu_mul_fp8_quant_deep_gemm_triton,
E, E,
...@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies): ...@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
baseline_results.append((time_ms_triton, gflops, gbps, perc)) triton_results.append((time_ms_triton, gflops, gbps, perc))
ratios.append(time_ms_triton / time_ms_cuda)
print(f"Completed: {config_label}") # Calculate speedup ratios (triton baseline / implementation)
all_cuda_results.append(cuda_results) triton_v2_ratio = time_ms_triton / time_ms_silu_v2
all_baseline_results.append(baseline_results) ratios.append(triton_v2_ratio)
print(
f"Completed: {config_label}:"
f" V2: {time_ms_silu_v2:.3f}ms,"
f" Triton: {time_ms_triton:.3f}ms"
)
all_silu_v2_results.append(silu_v2_results)
all_triton_results.append(triton_results)
all_ratios.append(ratios) all_ratios.append(ratios)
# Store results for combined plotting # Store results for combined plotting
...@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies): ...@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
( (
strategy_descriptions[strategy], strategy_descriptions[strategy],
all_ratios, all_ratios,
all_cuda_results, all_silu_v2_results,
all_baseline_results, all_triton_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) )
...@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies): ...@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
# Print summary table for this strategy # Print summary table for this strategy
print(f"\nSummary Table - {strategy_descriptions[strategy]}:") print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
print("-" * 60) print("-" * 90)
for i, (E, T, H) in enumerate(configs): for i, (E, T, H) in enumerate(configs):
speedup = baseline_results[i][0] / cuda_results[i][0] # Get the first result for each config (simplifying for summary)
v2_time = silu_v2_results[i][0]
triton_time = triton_results[i][0]
triton_v2_speedup = triton_time / v2_time
config_label = f"E={E:3d},T={T:4d},H={H:4d}" config_label = f"E={E:3d},T={T:4d},H={H:4d}"
print( print(
f"{config_label:<20} {cuda_results[i][0]:8.5f} " f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" f"{triton_v2_speedup:8.2f}x"
) )
...@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results): ...@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
num_strategies = len(all_results) num_strategies = len(all_results)
num_configs = len(configs) num_configs = len(configs)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig, axs = plt.subplots( fig, axs = plt.subplots(
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
) )
# Add main title to the entire figure # Add main title to the entire figure
fig.suptitle( fig.suptitle(
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
fontsize=16, fontsize=18,
fontweight="bold", fontweight="bold",
y=0.98, y=0.98,
) )
...@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results): ...@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
( (
strategy_name, strategy_name,
all_ratios, all_ratios,
all_cuda_results, all_silu_v2_results,
all_baseline_results, all_triton_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) = result ) = result
...@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results): ...@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
ratios = all_ratios[config_idx] ratios = all_ratios[config_idx]
total_tokens_values = config_x_axis[config_idx] total_tokens_values = config_x_axis[config_idx]
# Extract CUDA and Triton bandwidth percentages # Extract speedup ratios
cuda_bandwidth_percentages = [ triton_v2_ratios = [ratio for ratio in ratios]
result[3] for result in all_cuda_results[config_idx]
# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages = [
result[3] for result in all_silu_v2_results[config_idx]
] ]
triton_bandwidth_percentages = [ triton_bandwidth_percentages = [
result[3] for result in all_baseline_results[config_idx] result[3] for result in all_triton_results[config_idx]
] ]
# Plot speedup ratios vs total tokens (left plot) # Plot speedup ratios vs total tokens (left plot)
ax_speedup.plot( ax_speedup.plot(
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 total_tokens_values,
triton_v2_ratios,
"go-",
linewidth=3,
markersize=8,
label="Triton/V2 Speedup",
) )
ax_speedup.set_title( ax_speedup.set_title(
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
fontsize=12, fontsize=12,
fontweight="bold", fontweight="bold",
) )
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
ax_speedup.legend(prop={"weight": "bold"})
ax_speedup.grid(True, alpha=0.3) ax_speedup.grid(True, alpha=0.3)
# Plot bandwidth utilization (right plot)
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
cuda_bandwidth_percentages, v2_bandwidth_percentages,
"ro-", "o-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="CUDA", label="SiLU V2",
color="blue",
) )
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
triton_bandwidth_percentages, triton_bandwidth_percentages,
"go-", "o-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="Triton", label="Triton",
color="green",
) )
ax_bandwidth.set_title( ax_bandwidth.set_title(
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
...@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results): ...@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
for label in ax.get_xticklabels() + ax.get_yticklabels(): for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontweight("bold") label.set_fontweight("bold")
# Add value labels on speedup points # Add value labels on Triton/V2 speedup points
for x, y in zip(total_tokens_values, ratios): for x, y in zip(total_tokens_values, triton_v2_ratios):
ax_speedup.annotate( ax_speedup.annotate(
f"{y:.2f}x", f"{y:.2f}x",
(x, y), (x, y),
textcoords="offset points", textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=10,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
# Add value labels on CUDA bandwidth points
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=9,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
)
# Add value labels on Triton bandwidth points
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, -15), xytext=(0, -15),
ha="center", ha="center",
fontsize=9, fontsize=9,
...@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results): ...@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
plt.tight_layout() plt.tight_layout()
plt.subplots_adjust(top=0.93) # Make room for main title plt.subplots_adjust(top=0.93) # Make room for main title
filename = "silu_benchmark_total_tokens.png" filename = "silu_benchmark_total_tokens_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
return filename return filename
# Create combined plot with all strategies # Create comprehensive 3-way comparison plots
combined_plot_filename = create_total_tokens_plot(all_results) combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)
print(f"\n{'=' * 60}") print(f"\n{'=' * 80}")
print("Benchmark Complete!") print("3-Way Benchmark Suite Complete!")
print(f"Generated combined plot: {combined_plot_filename}") print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"{'=' * 60}") print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): ...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_decode( def benchmark_decode(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[ quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),
...@@ -130,8 +127,8 @@ def benchmark_decode( ...@@ -130,8 +127,8 @@ def benchmark_decode(
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
for i in range(warmup): for i in range(warmup):
fn() fn()
......
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up from vllm.utils.math_utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): ...@@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_prefill( def benchmark_prefill(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[ quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),
...@@ -142,8 +139,8 @@ def benchmark_prefill( ...@@ -142,8 +139,8 @@ def benchmark_prefill(
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
for i in range(warmup): for i in range(warmup):
fn() fn()
......
...@@ -14,11 +14,11 @@ import torch ...@@ -14,11 +14,11 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul, _w8a8_triton_block_scaled_mm,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
...@@ -83,7 +83,7 @@ def w8a8_block_matmul( ...@@ -83,7 +83,7 @@ def w8a8_block_matmul(
) )
if A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul kernel = _w8a8_triton_block_scaled_mm
else: else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
...@@ -183,8 +183,8 @@ def benchmark_config( ...@@ -183,8 +183,8 @@ def benchmark_config(
run() run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
This directory includes benchmarks comparing DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing Triton and CUTLASS-based kernels. This directory includes benchmarks comparing DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing Triton and CUTLASS-based kernels.
Currently this just includes dense GEMMs and only works on Hopper GPUs. Currently, this just includes dense GEMMs and only works on Hopper GPUs.
## Setup ## Setup
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501 # ruff: noqa: E501
import time import time
...@@ -9,7 +8,7 @@ import torch ...@@ -9,7 +8,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
w8a8_block_fp8_matmul, w8a8_triton_block_scaled_mm,
) )
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
...@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import ( ...@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
) )
def benchmark_shape(m: int, def benchmark_shape(
n: int, m: int,
k: int, n: int,
warmup: int = 100, k: int,
repeat: int = 10000, warmup: int = 100,
verbose: bool = False) -> dict: repeat: int = 10000,
verbose: bool = False,
) -> dict:
"""Benchmark all implementations for a specific (m, n, k) shape.""" """Benchmark all implementations for a specific (m, n, k) shape."""
if verbose: if verbose:
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
# Create test tensors # Create test tensors
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16 # Reference result in BF16
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -49,34 +50,39 @@ def benchmark_shape(m: int, ...@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
# Pre-quantize A for all implementations # Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True) A, block_size[1], column_major_scales=True
)
# === DeepGEMM Implementation === # === DeepGEMM Implementation ===
def deepgemm_gemm(): def deepgemm_gemm():
fp8_gemm_nt((A_deepgemm, A_scale_deepgemm), fp8_gemm_nt(
(B_deepgemm, B_scale_deepgemm), (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
C_deepgemm) )
return C_deepgemm return C_deepgemm
# === vLLM Triton Implementation === # === vLLM Triton Implementation ===
def vllm_triton_gemm(): def vllm_triton_gemm():
return w8a8_block_fp8_matmul(A_vllm, return w8a8_triton_block_scaled_mm(
B_vllm, A_vllm,
A_scale_vllm, B_vllm,
B_scale_vllm, A_scale_vllm,
block_size, B_scale_vllm,
output_dtype=torch.bfloat16) block_size,
output_dtype=torch.bfloat16,
)
# === vLLM CUTLASS Implementation === # === vLLM CUTLASS Implementation ===
def vllm_cutlass_gemm(): def vllm_cutlass_gemm():
return ops.cutlass_scaled_mm(A_vllm_cutlass, return ops.cutlass_scaled_mm(
B_vllm.T, A_vllm_cutlass,
scale_a=A_scale_vllm_cutlass, B_vllm.T,
scale_b=B_scale_vllm.T, scale_a=A_scale_vllm_cutlass,
out_dtype=torch.bfloat16) scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16,
)
# Run correctness check first # Run correctness check first
if verbose: if verbose:
...@@ -93,26 +99,23 @@ def benchmark_shape(m: int, ...@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
print("vLLM Triton vs DeepGEMM difference: " print(
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") "vLLM Triton vs DeepGEMM difference: "
print("vLLM CUTLASS vs DeepGEMM difference: " f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") )
print(
"vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
)
# Benchmark implementations # Benchmark implementations
implementations = { implementations = {
"DeepGEMM": deepgemm_gemm, "DeepGEMM": deepgemm_gemm,
"vLLM Triton": vllm_triton_gemm, "vLLM Triton": vllm_triton_gemm,
"vLLM CUTLASS": vllm_cutlass_gemm "vLLM CUTLASS": vllm_cutlass_gemm,
} }
benchmark_results = { benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
"shape": {
"m": m,
"n": n,
"k": k
},
"implementations": {}
}
for name, func in implementations.items(): for name, func in implementations.items():
# Warmup # Warmup
...@@ -140,38 +143,36 @@ def benchmark_shape(m: int, ...@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
"tflops": tflops, "tflops": tflops,
"gb_s": gb_s, "gb_s": gb_s,
"diff": { "diff": {
"DeepGEMM": "DeepGEMM": 0.0
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), if name == "DeepGEMM"
"Reference": else calc_diff(func(), C_deepgemm),
deepgemm_diff if name == "DeepGEMM" else "Reference": deepgemm_diff
(vllm_triton_diff if name == "DeepGEMM"
if name == "vLLM Triton" else vllm_cutlass_diff) else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
} },
} }
if verbose: if verbose:
print( print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
)
# Calculate speedups # Calculate speedups
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
for name, data in benchmark_results["implementations"].items(): for name, data in benchmark_results["implementations"].items():
if name != "DeepGEMM": if name != "DeepGEMM":
speedup = baseline / data["time_ms"] speedup = baseline / data["time_ms"]
benchmark_results["implementations"][name][ benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
"speedup_vs_deepgemm"] = speedup
if verbose: if verbose:
print(f"DeepGEMM is {1/speedup:.2f}x " print(
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") f"DeepGEMM is {1 / speedup:.2f}x "
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
)
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
"time_ms"] vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
"time_ms"]
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
benchmark_results["implementations"]["vLLM CUTLASS"][ benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
"speedup_vs_triton"] = cutlass_vs_triton cutlass_vs_triton
)
if verbose: if verbose:
print( print(
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
...@@ -183,8 +184,7 @@ def benchmark_shape(m: int, ...@@ -183,8 +184,7 @@ def benchmark_shape(m: int,
def format_table_row(values, widths): def format_table_row(values, widths):
"""Format a row with specified column widths.""" """Format a row with specified column widths."""
return "| " + " | ".join(f"{val:{w}}" return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
for val, w in zip(values, widths)) + " |"
def print_table(headers, rows, title=None): def print_table(headers, rows, title=None):
...@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False): ...@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
for result in all_results: for result in all_results:
shape = result["shape"] shape = result["shape"]
impl_data = result["implementations"]["DeepGEMM"] impl_data = result["implementations"]["DeepGEMM"]
deepgemm_rows.append([ deepgemm_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" shape["m"],
]) shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
]
)
print_table(deepgemm_headers, print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
deepgemm_rows,
title="DeepGEMM Implementation:")
# Print vLLM Triton table # Print vLLM Triton table
triton_headers = [ triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
]
triton_rows = [] triton_rows = []
for result in all_results: for result in all_results:
shape = result["shape"] shape = result["shape"]
impl_data = result["implementations"]["vLLM Triton"] impl_data = result["implementations"]["vLLM Triton"]
speedup = impl_data.get("speedup_vs_deepgemm", 1.0) speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
triton_rows.append([ triton_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", shape["m"],
format_speedup(speedup) shape["n"],
]) shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(speedup),
]
)
print_table(triton_headers, print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
triton_rows,
title="vLLM Triton Implementation:")
# Print vLLM CUTLASS table # Print vLLM CUTLASS table
cutlass_headers = [ cutlass_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", "m",
"vs Triton" "n",
"k",
"Time (μs)",
"TFLOPS",
"GB/s",
"vs DeepGEMM",
"vs Triton",
] ]
cutlass_rows = [] cutlass_rows = []
for result in all_results: for result in all_results:
...@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False): ...@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
impl_data = result["implementations"]["vLLM CUTLASS"] impl_data = result["implementations"]["vLLM CUTLASS"]
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
vs_triton = impl_data.get("speedup_vs_triton", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0)
cutlass_rows.append([ cutlass_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", shape["m"],
format_speedup(vs_deepgemm), shape["n"],
format_speedup(vs_triton) shape["k"],
]) f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton),
]
)
print_table(cutlass_headers, print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
cutlass_rows,
title="vLLM CUTLASS Implementation:")
# Calculate and print averages # Calculate and print averages
print("\n===== AVERAGE PERFORMANCE =====") print("\n===== AVERAGE PERFORMANCE =====")
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
avg_metrics = { avg_metrics = {
impl: { impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
"tflops": 0,
"gb_s": 0,
"time_ms": 0
}
for impl in implementations
} }
for result in all_results: for result in all_results:
...@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False): ...@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
avg_time = avg_metrics[impl]["time_ms"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes
avg_rows.append([ avg_rows.append(
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
]) )
print_table(avg_headers, avg_rows) print_table(avg_headers, avg_rows)
...@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False): ...@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
avg_speedups = { avg_speedups = {
"DeepGEMM vs vLLM Triton": 0, "DeepGEMM vs vLLM Triton": 0,
"DeepGEMM vs vLLM CUTLASS": 0, "DeepGEMM vs vLLM CUTLASS": 0,
"vLLM CUTLASS vs vLLM Triton": 0 "vLLM CUTLASS vs vLLM Triton": 0,
} }
for result in all_results: for result in all_results:
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
"time_ms"]
avg_speedups[ avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups[ avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time vllm_triton_time / vllm_cutlass_time
avg_speedups[ )
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
print("\n===== AVERAGE SPEEDUPS =====") print("\n===== AVERAGE SPEEDUPS =====")
speedup_headers = ["Comparison", "Speedup"] speedup_headers = ["Comparison", "Speedup"]
...@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False): ...@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
for result in all_results: for result in all_results:
for impl in implementations: for impl in implementations:
avg_diff[impl] += result["implementations"][impl]["diff"][ avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
"Reference"]
diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"]
diff_rows = [] diff_rows = []
......
...@@ -11,7 +11,7 @@ import regex as re ...@@ -11,7 +11,7 @@ import regex as re
import seaborn as sns import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from collections.abc import Iterable from collections.abc import Callable, Iterable
from typing import Any, Callable, Optional from typing import Any
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -55,7 +55,7 @@ class Bench: ...@@ -55,7 +55,7 @@ class Bench:
def __init__( def __init__(
self, self,
cuda_graph_params: Optional[CudaGraphBenchParams], cuda_graph_params: CudaGraphBenchParams | None,
label: str, label: str,
sub_label: str, sub_label: str,
description: str, description: str,
......
...@@ -55,6 +55,10 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 ...@@ -55,6 +55,10 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75
---------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------
``` ```
If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec`
and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the
benchmark-only runtime so the reported throughput stays comparable).
### JSON configuration file for synthetic conversations generation ### JSON configuration file for synthetic conversations generation
The input flag `--input-file` is used to determine the input conversations for the benchmark.<br/> The input flag `--input-file` is used to determine the input conversations for the benchmark.<br/>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from statistics import mean from statistics import mean
from typing import Any, NamedTuple, Optional, Union from typing import Any, NamedTuple
import numpy as np # type: ignore import numpy as np # type: ignore
import pandas as pd # type: ignore import pandas as pd # type: ignore
...@@ -11,6 +11,7 @@ from bench_utils import ( ...@@ -11,6 +11,7 @@ from bench_utils import (
Color, Color,
logger, logger,
) )
from tqdm import tqdm
from transformers import AutoTokenizer # type: ignore from transformers import AutoTokenizer # type: ignore
# Conversation ID is a string (e.g: "UzTK34D") # Conversation ID is a string (e.g: "UzTK34D")
...@@ -35,8 +36,8 @@ class Distribution(ABC): ...@@ -35,8 +36,8 @@ class Distribution(ABC):
class UniformDistribution(Distribution): class UniformDistribution(Distribution):
def __init__( def __init__(
self, self,
min_val: Union[int, float], min_val: int | float,
max_val: Union[int, float], max_val: int | float,
is_integer: bool = True, is_integer: bool = True,
) -> None: ) -> None:
self.min_val = min_val self.min_val = min_val
...@@ -56,7 +57,7 @@ class UniformDistribution(Distribution): ...@@ -56,7 +57,7 @@ class UniformDistribution(Distribution):
class ConstantDistribution(Distribution): class ConstantDistribution(Distribution):
def __init__(self, value: Union[int, float]) -> None: def __init__(self, value: int | float) -> None:
self.value = value self.value = value
self.max_val = value self.max_val = value
...@@ -68,7 +69,7 @@ class ConstantDistribution(Distribution): ...@@ -68,7 +69,7 @@ class ConstantDistribution(Distribution):
class ZipfDistribution(Distribution): class ZipfDistribution(Distribution):
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: def __init__(self, alpha: float, max_val: int | None = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
...@@ -83,7 +84,7 @@ class ZipfDistribution(Distribution): ...@@ -83,7 +84,7 @@ class ZipfDistribution(Distribution):
class PoissonDistribution(Distribution): class PoissonDistribution(Distribution):
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: def __init__(self, alpha: float, max_val: int | None = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
...@@ -100,11 +101,11 @@ class PoissonDistribution(Distribution): ...@@ -100,11 +101,11 @@ class PoissonDistribution(Distribution):
class LognormalDistribution(Distribution): class LognormalDistribution(Distribution):
def __init__( def __init__(
self, self,
mean: Optional[float] = None, mean: float | None = None,
sigma: Optional[float] = None, sigma: float | None = None,
average: Optional[int] = None, average: int | None = None,
median_ratio: Optional[float] = None, median_ratio: float | None = None,
max_val: Optional[int] = None, max_val: int | None = None,
) -> None: ) -> None:
self.average = average self.average = average
self.median_ratio = median_ratio self.median_ratio = median_ratio
...@@ -417,6 +418,10 @@ def generate_conversations( ...@@ -417,6 +418,10 @@ def generate_conversations(
data = file.read() data = file.read()
tokens_in_file = tokenizer.encode(data, add_special_tokens=False) tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
list_of_tokens.extend(tokens_in_file) list_of_tokens.extend(tokens_in_file)
logger.info(
f"Loaded {len(tokens_in_file)} tokens from file {filename}, "
f"total tokens so far: {len(list_of_tokens)}"
)
conversations: ConversationsMap = {} conversations: ConversationsMap = {}
conv_id = 0 conv_id = 0
...@@ -449,18 +454,25 @@ def generate_conversations( ...@@ -449,18 +454,25 @@ def generate_conversations(
) )
base_offset += common_prefix_tokens base_offset += common_prefix_tokens
for conv_id in range(args.num_conversations): for conv_id in tqdm(
range(args.num_conversations),
total=args.num_conversations,
desc="Generating conversations",
unit="conv",
):
# Generate a single conversation # Generate a single conversation
messages: MessagesList = [] messages: MessagesList = []
nturns = turn_count[conv_id] nturns = turn_count[conv_id]
# User prompt token count per turn (with lower limit) # User prompt token count per turn (with lower limit)
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int)
input_token_count = np.maximum(input_token_count, base_prompt_token_count) input_token_count = np.maximum(input_token_count, base_prompt_token_count)
# Assistant answer token count per turn (with lower limit) # Assistant answer token count per turn (with lower limit)
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype(
int
)
output_token_count = np.maximum(output_token_count, 1) output_token_count = np.maximum(output_token_count, 1)
user_turn = True user_turn = True
......
...@@ -13,7 +13,7 @@ from datetime import datetime ...@@ -13,7 +13,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from statistics import mean from statistics import mean
from typing import NamedTuple, Optional, Union from typing import NamedTuple
import aiohttp # type: ignore import aiohttp # type: ignore
import numpy as np # type: ignore import numpy as np # type: ignore
...@@ -46,15 +46,16 @@ class ConversationSampling(str, Enum): ...@@ -46,15 +46,16 @@ class ConversationSampling(str, Enum):
class ClientArgs(NamedTuple): class ClientArgs(NamedTuple):
seed: int seed: int
max_num_requests: Optional[int] max_num_requests: int | None
skip_first_turn: bool skip_first_turn: bool
max_turns: Optional[int] max_turns: int | None
max_active_conversations: int max_active_conversations: int
verbose: bool verbose: bool
print_content: bool print_content: bool
verify_output: bool verify_output: bool
conversation_sampling: ConversationSampling conversation_sampling: ConversationSampling
request_rate: float request_rate: float
max_retries: int
class RequestArgs(NamedTuple): class RequestArgs(NamedTuple):
...@@ -63,6 +64,7 @@ class RequestArgs(NamedTuple): ...@@ -63,6 +64,7 @@ class RequestArgs(NamedTuple):
stream: bool stream: bool
limit_min_tokens: int # Use negative value for no limit limit_min_tokens: int # Use negative value for no limit
limit_max_tokens: int # Use negative value for no limit limit_max_tokens: int # Use negative value for no limit
timeout_sec: int
class BenchmarkArgs(NamedTuple): class BenchmarkArgs(NamedTuple):
...@@ -109,9 +111,9 @@ class RequestStats(NamedTuple): ...@@ -109,9 +111,9 @@ class RequestStats(NamedTuple):
class MetricStats: class MetricStats:
def __init__(self) -> None: def __init__(self) -> None:
self.min: Optional[float] = None self.min: float | None = None
self.max: Optional[float] = None self.max: float | None = None
self.avg: Optional[float] = None self.avg: float | None = None
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
...@@ -143,7 +145,7 @@ class MovingAverage: ...@@ -143,7 +145,7 @@ class MovingAverage:
self.index = 0 self.index = 0
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
self.avg: Optional[float] = None self.avg: float | None = None
def update(self, new_value: float) -> None: def update(self, new_value: float) -> None:
if self.count < self.window_size: if self.count < self.window_size:
...@@ -169,7 +171,7 @@ class MovingAverage: ...@@ -169,7 +171,7 @@ class MovingAverage:
class DebugStats: class DebugStats:
def __init__(self, logger: logging.Logger, window_size: int) -> None: def __init__(self, logger: logging.Logger, window_size: int) -> None:
self.logger = logger self.logger = logger
self.metrics: dict[str, Union[MovingAverage, MetricStats]] = { self.metrics: dict[str, MovingAverage | MetricStats] = {
"moving_avg_ttft_ms": MovingAverage(window_size), "moving_avg_ttft_ms": MovingAverage(window_size),
"moving_avg_tpot_ms": MovingAverage(window_size), "moving_avg_tpot_ms": MovingAverage(window_size),
"ttft_ms": MetricStats(), "ttft_ms": MetricStats(),
...@@ -198,14 +200,6 @@ class DebugStats: ...@@ -198,14 +200,6 @@ class DebugStats:
self.logger.info("-" * 50) self.logger.info("-" * 50)
# Kept as a named helper for existing callers. The surrounding code already
# relies on `X | None` union annotations (Python 3.10+), so the built-in
# str.removeprefix (Python 3.9+) is always available here.
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a single leading ``prefix`` stripped.

    Args:
        text: The string to strip.
        prefix: The prefix to remove when ``text`` starts with it.

    Returns:
        ``text`` without the leading ``prefix``. If ``text`` does not start
        with ``prefix`` (or ``prefix`` is empty), ``text`` is returned
        unchanged. Only one occurrence is removed.
    """
    return text.removeprefix(prefix)
def nanosec_to_millisec(value: float) -> float: def nanosec_to_millisec(value: float) -> float:
return value / 1000000.0 return value / 1000000.0
...@@ -220,8 +214,9 @@ async def send_request( ...@@ -220,8 +214,9 @@ async def send_request(
chat_url: str, chat_url: str,
model: str, model: str,
stream: bool = True, stream: bool = True,
min_tokens: Optional[int] = None, min_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
timeout_sec: int = 120,
) -> ServerResponse: ) -> ServerResponse:
payload = { payload = {
"model": model, "model": model,
...@@ -243,16 +238,22 @@ async def send_request( ...@@ -243,16 +238,22 @@ async def send_request(
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
# Calculate the timeout for the request # Calculate the timeout for the request
timeout_sec = 120
if max_tokens is not None: if max_tokens is not None:
# Assume TPOT of 200ms and use max_tokens to determine timeout # Assume TPOT of 200ms and use max_tokens to determine timeout
timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) token_based_timeout = int(max_tokens * 0.2)
if token_based_timeout > timeout_sec:
timeout_sec = token_based_timeout
logger.info(
"Using timeout of %ds based on max_tokens %d",
timeout_sec,
max_tokens,
)
timeout = aiohttp.ClientTimeout(total=timeout_sec) timeout = aiohttp.ClientTimeout(total=timeout_sec)
valid_response = True valid_response = True
ttft: Optional[float] = None ttft: float | None = None
chunk_delay: list[int] = [] chunk_delay: list[int] = []
latency: Optional[float] = None latency: float | None = None
first_chunk = "" first_chunk = ""
generated_text = "" generated_text = ""
...@@ -269,7 +270,7 @@ async def send_request( ...@@ -269,7 +270,7 @@ async def send_request(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
# End of stream # End of stream
latency = time.perf_counter_ns() - start_time latency = time.perf_counter_ns() - start_time
...@@ -364,7 +365,7 @@ async def send_turn( ...@@ -364,7 +365,7 @@ async def send_turn(
req_args: RequestArgs, req_args: RequestArgs,
verbose: bool, verbose: bool,
verify_output: bool, verify_output: bool,
) -> Optional[RequestStats]: ) -> RequestStats | None:
assert messages_to_use > 0 assert messages_to_use > 0
assert messages_to_use <= len(conversation_messages) assert messages_to_use <= len(conversation_messages)
...@@ -417,6 +418,7 @@ async def send_turn( ...@@ -417,6 +418,7 @@ async def send_turn(
req_args.stream, req_args.stream,
min_tokens, min_tokens,
max_tokens, max_tokens,
req_args.timeout_sec,
) )
if response.valid is False: if response.valid is False:
...@@ -526,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: ...@@ -526,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
await asyncio.sleep(interval) await asyncio.sleep(interval)
async def exponential_backoff_sleep(
attempt_cnt: int,
base_rate: float = 1.0,
backoff_factor: float = 2.0,
jitter_fraction: float = 0.10,
verbose: bool = False,
) -> None:
# Sleep with exponential backoff and jitter after a failed request.
backoff_delay = base_rate * (backoff_factor**attempt_cnt)
jittered_delay = backoff_delay * (
1 + np.random.uniform(-jitter_fraction, jitter_fraction)
)
if verbose:
logger.info(f"Backoff for {jittered_delay:.3f} seconds...")
await asyncio.sleep(jittered_delay)
async def client_main( async def client_main(
args: ClientArgs, args: ClientArgs,
req_args: RequestArgs, req_args: RequestArgs,
...@@ -540,8 +561,11 @@ async def client_main( ...@@ -540,8 +561,11 @@ async def client_main(
f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
) )
random.seed(args.seed) # Set unique seed per client (each client runs in its own process)
np.random.seed(args.seed) # Add 1 to ensure no client uses the same seed as the main process
client_seed = args.seed + client_id + 1
random.seed(client_seed)
np.random.seed(client_seed)
# Active conversations # Active conversations
active_convs: ConversationsMap = {} active_convs: ConversationsMap = {}
...@@ -644,7 +668,7 @@ async def client_main( ...@@ -644,7 +668,7 @@ async def client_main(
if args.verbose: if args.verbose:
curr_time_sec: float = time.perf_counter() curr_time_sec: float = time.perf_counter()
time_since_last_turn: Union[str, float] = "N/A" time_since_last_turn: str | float = "N/A"
if conv_id in time_of_last_turn: if conv_id in time_of_last_turn:
time_since_last_turn = round( time_since_last_turn = round(
curr_time_sec - time_of_last_turn[conv_id], 3 curr_time_sec - time_of_last_turn[conv_id], 3
...@@ -654,49 +678,62 @@ async def client_main( ...@@ -654,49 +678,62 @@ async def client_main(
) )
time_of_last_turn[conv_id] = curr_time_sec time_of_last_turn[conv_id] = curr_time_sec
success = True success = False
try: for attempt_cnt in range(args.max_retries + 1):
result = await send_turn( try:
session, exception = False
client_id, result = await send_turn(
conv_id, session,
messages, client_id,
current_turn, conv_id,
tokenizer, messages,
req_args, current_turn,
args.print_content, tokenizer,
args.verify_output, req_args,
) args.print_content,
if result is not None: args.verify_output,
result_queue.put(result) )
else: if result is not None:
# None means that the request failed, result_queue.put(result)
# and should not be added to the statistics. success = True
success = False break
num_failures += 1 else:
logger.warning(
logger.warning( f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 )
except asyncio.exceptions.TimeoutError:
exception = True
logger.error(
"%sClient %d - Timeout during conversation ID %s (turn: %d). "
"Base timeout is %ss (set with --request-timeout-sec), but the "
"effective timeout may be longer based on max_tokens. If this "
"is unexpected, consider increasing the timeout or checking "
"model performance.%s",
Color.RED,
client_id,
conv_id,
current_turn,
req_args.timeout_sec,
Color.RESET,
)
except Exception:
exception = True
logger.exception(
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
) )
# Remove the conversation (should not be used again) # Sleep before retry if not last attempt
active_convs.pop(conv_id) if not success and attempt_cnt < args.max_retries:
await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose)
except asyncio.exceptions.TimeoutError: if not success:
num_failures += 1 num_failures += 1
logger.exception( # Remove the conversation (should not be used again)
f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 active_convs.pop(conv_id)
) if exception:
break # Exit gracefully instead of raising an error break # Exit gracefully instead of raising an error
except Exception: else:
num_failures += 1
logger.exception(
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
)
break # Exit gracefully instead of raising an error
if success:
num_successes += 1 num_successes += 1
# Update the turns counter to include the LLM response # Update the turns counter to include the LLM response
...@@ -769,7 +806,7 @@ def get_client_config( ...@@ -769,7 +806,7 @@ def get_client_config(
"Number of conversations must be equal or larger than the number of clients" "Number of conversations must be equal or larger than the number of clients"
) )
max_req_per_client: Optional[int] = None max_req_per_client: int | None = None
if args.max_num_requests is not None: if args.max_num_requests is not None:
# Max number of requests per client # Max number of requests per client
req_per_client = args.max_num_requests // args.num_clients req_per_client = args.max_num_requests // args.num_clients
...@@ -811,6 +848,7 @@ def get_client_config( ...@@ -811,6 +848,7 @@ def get_client_config(
verify_output=args.verify_output, verify_output=args.verify_output,
conversation_sampling=args.conversation_sampling, conversation_sampling=args.conversation_sampling,
request_rate=args.request_rate, request_rate=args.request_rate,
max_retries=args.max_retries,
) )
if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
...@@ -823,6 +861,9 @@ def get_client_config( ...@@ -823,6 +861,9 @@ def get_client_config(
"Invalid min/max tokens limits (min should not be larger than max)" "Invalid min/max tokens limits (min should not be larger than max)"
) )
if args.request_timeout_sec <= 0:
raise ValueError("Request timeout must be a positive number")
# Arguments for API requests # Arguments for API requests
chat_url = f"{args.url}/v1/chat/completions" chat_url = f"{args.url}/v1/chat/completions"
model_name = args.served_model_name if args.served_model_name else args.model model_name = args.served_model_name if args.served_model_name else args.model
...@@ -833,6 +874,7 @@ def get_client_config( ...@@ -833,6 +874,7 @@ def get_client_config(
stream=not args.no_stream, stream=not args.no_stream,
limit_min_tokens=args.limit_min_tokens, limit_min_tokens=args.limit_min_tokens,
limit_max_tokens=args.limit_max_tokens, limit_max_tokens=args.limit_max_tokens,
timeout_sec=args.request_timeout_sec,
) )
return client_args, req_args return client_args, req_args
...@@ -936,13 +978,13 @@ async def main_mp( ...@@ -936,13 +978,13 @@ async def main_mp(
f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501
) )
rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3) rps: str | float = round(len(client_metrics) / runtime_sec, 3)
if len(client_metrics) < (5 * bench_args.num_clients): if len(client_metrics) < (5 * bench_args.num_clients):
# Do not estimate the RPS if the number of samples is very low # Do not estimate the RPS if the number of samples is very low
# (threshold can be tuned if needed) # (threshold can be tuned if needed)
rps = "N/A" rps = "N/A"
runtime_left_sec: Union[str, float] = round( runtime_left_sec: str | float = round(
(runtime_sec / finished_convs) * (total_convs - finished_convs), 3 (runtime_sec / finished_convs) * (total_convs - finished_convs), 3
) )
if percent < 0.05: if percent < 0.05:
...@@ -976,7 +1018,7 @@ async def main_mp( ...@@ -976,7 +1018,7 @@ async def main_mp(
f"(is alive: {client.is_alive()}){Color.RESET}" f"(is alive: {client.is_alive()}){Color.RESET}"
) )
client.join(timeout=120) client.join(timeout=req_args.timeout_sec + 1)
if client.is_alive(): if client.is_alive():
logger.warning( logger.warning(
...@@ -1032,8 +1074,9 @@ def process_statistics( ...@@ -1032,8 +1074,9 @@ def process_statistics(
warmup_percentages: list[float], warmup_percentages: list[float],
test_params: dict, test_params: dict,
verbose: bool, verbose: bool,
gen_conv_args: Optional[GenConvArgs] = None, gen_conv_args: GenConvArgs | None = None,
excel_output: bool = False, excel_output: bool = False,
warmup_runtime_sec: float | None = None,
) -> None: ) -> None:
if len(client_metrics) == 0: if len(client_metrics) == 0:
logger.info("No samples to process") logger.info("No samples to process")
...@@ -1127,8 +1170,13 @@ def process_statistics( ...@@ -1127,8 +1170,13 @@ def process_statistics(
# Convert milliseconds to seconds # Convert milliseconds to seconds
runtime_sec = runtime_sec / 1000.0 runtime_sec = runtime_sec / 1000.0
requests_per_sec = float(len(df)) / runtime_sec requests_per_sec = float(len(df)) / runtime_sec
params = {
params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec} "runtime_sec": runtime_sec,
"requests_per_sec": requests_per_sec,
}
if warmup_runtime_sec is not None:
params["warmup_runtime_sec"] = warmup_runtime_sec
params["total_runtime_incl_warmup_sec"] = runtime_sec + warmup_runtime_sec
# Generate a summary of relevant metrics (and drop irrelevant data) # Generate a summary of relevant metrics (and drop irrelevant data)
df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose() df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose()
...@@ -1259,7 +1307,7 @@ async def main() -> None: ...@@ -1259,7 +1307,7 @@ async def main() -> None:
default=None, default=None,
help="The model name used in the API. " help="The model name used in the API. "
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the ``--model`` argument. ", "same as the `--model` argument. ",
) )
parser.add_argument( parser.add_argument(
...@@ -1342,6 +1390,16 @@ async def main() -> None: ...@@ -1342,6 +1390,16 @@ async def main() -> None:
help="Expected request rate (Poisson process) per client in requests/sec." help="Expected request rate (Poisson process) per client in requests/sec."
"Set to 0 for no delay between requests.", "Set to 0 for no delay between requests.",
) )
parser.add_argument(
"--max-retries",
type=int,
default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")),
help="Maximum number of retry attempts for timed-out requests. "
"Default is 0 (no retries). "
"Set to higher values to retry failed requests and maintain "
"fair workload distribution. "
"Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.",
)
parser.add_argument( parser.add_argument(
"--conversation-sampling", "--conversation-sampling",
type=ConversationSampling, type=ConversationSampling,
...@@ -1359,6 +1417,13 @@ async def main() -> None: ...@@ -1359,6 +1417,13 @@ async def main() -> None:
action="store_true", action="store_true",
help="Verify the LLM output (compare to the answers in the input JSON file)", help="Verify the LLM output (compare to the answers in the input JSON file)",
) )
parser.add_argument(
"--request-timeout-sec",
type=int,
default=120,
help="Timeout in seconds for each API request (default: 120). "
"Automatically increased if max tokens imply longer decoding.",
)
parser.add_argument( parser.add_argument(
"--no-stream", "--no-stream",
...@@ -1434,11 +1499,10 @@ async def main() -> None: ...@@ -1434,11 +1499,10 @@ async def main() -> None:
f"Invalid --warmup-percentage={args.warmup_percentage}" f"Invalid --warmup-percentage={args.warmup_percentage}"
) from None ) from None
# Set global seeds for main process
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
if not os.path.exists(args.model):
raise OSError(f"Path does not exist: {args.model}")
logger.info("Loading tokenizer") logger.info("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model)
...@@ -1494,6 +1558,8 @@ async def main() -> None: ...@@ -1494,6 +1558,8 @@ async def main() -> None:
url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
) )
warmup_runtime_sec: float | None = None
# Warm-up step # Warm-up step
if args.warmup_step: if args.warmup_step:
# Only send a single user prompt from every conversation. # Only send a single user prompt from every conversation.
...@@ -1508,26 +1574,56 @@ async def main() -> None: ...@@ -1508,26 +1574,56 @@ async def main() -> None:
# all clients should finish their work before exiting # all clients should finish their work before exiting
warmup_bench_args = bench_args._replace(early_stop=False) warmup_bench_args = bench_args._replace(early_stop=False)
logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}") logger.info("%sWarmup start%s", Color.PURPLE, Color.RESET)
warmup_start_ns = time.perf_counter_ns()
conversations, _ = await main_mp( conversations, _ = await main_mp(
warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations
) )
logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}") warmup_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - warmup_start_ns)
logger.info(
"%sWarmup runtime: %.3f sec (%.3f ms)%s",
Color.PURPLE,
warmup_runtime_sec,
warmup_runtime_sec * 1000,
Color.RESET,
)
logger.info("%sWarmup done%s", Color.PURPLE, Color.RESET)
# Run the benchmark # Run the benchmark
start_time = time.perf_counter_ns() benchmark_start_ns = time.perf_counter_ns()
client_convs, client_metrics = await main_mp( client_convs, client_metrics = await main_mp(
client_args, req_args, bench_args, tokenizer, conversations client_args, req_args, bench_args, tokenizer, conversations
) )
total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time) benchmark_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - benchmark_start_ns)
# Calculate requests per second # Calculate requests per second
total_runtime_sec = total_runtime_ms / 1000.0 requests_per_sec = len(client_metrics) / benchmark_runtime_sec
rps = len(client_metrics) / total_runtime_sec benchmark_runtime_ms = benchmark_runtime_sec * 1000.0
logger.info( logger.info(
f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec" "%sAll clients finished, benchmark runtime: %.3f sec (%.3f ms), "
f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}" "requests per second: %.3f%s",
Color.GREEN,
benchmark_runtime_sec,
benchmark_runtime_ms,
requests_per_sec,
Color.RESET,
) )
if warmup_runtime_sec is not None:
total_runtime_sec = benchmark_runtime_sec + warmup_runtime_sec
logger.info(
"%sWarmup runtime: %.3f sec (%.3f ms)%s",
Color.GREEN,
warmup_runtime_sec,
warmup_runtime_sec * 1000,
Color.RESET,
)
logger.info(
"%sTotal runtime (including warmup): %.3f sec (%.3f ms)%s",
Color.GREEN,
total_runtime_sec,
total_runtime_sec * 1000,
Color.RESET,
)
# Benchmark parameters # Benchmark parameters
params = { params = {
...@@ -1552,6 +1648,7 @@ async def main() -> None: ...@@ -1552,6 +1648,7 @@ async def main() -> None:
verbose=args.verbose, verbose=args.verbose,
gen_conv_args=gen_conv_args, gen_conv_args=gen_conv_args,
excel_output=args.excel_output, excel_output=args.excel_output,
warmup_runtime_sec=warmup_runtime_sec,
) )
if args.output_file is not None: if args.output_file is not None:
......
...@@ -13,7 +13,7 @@ import argparse ...@@ -13,7 +13,7 @@ import argparse
import json import json
import random import random
from statistics import mean from statistics import mean
from typing import Any, Optional from typing import Any
import pandas as pd # type: ignore import pandas as pd # type: ignore
import tqdm # type: ignore import tqdm # type: ignore
...@@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool: ...@@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool:
def content_is_valid( def content_is_valid(
content: str, min_content_len: Optional[int], max_content_len: Optional[int] content: str, min_content_len: int | None, max_content_len: int | None
) -> bool: ) -> bool:
if min_content_len and len(content) < min_content_len: if min_content_len and len(content) < min_content_len:
return False return False
...@@ -37,7 +37,7 @@ def content_is_valid( ...@@ -37,7 +37,7 @@ def content_is_valid(
def print_stats( def print_stats(
conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None
) -> None: ) -> None:
# Collect statistics # Collect statistics
stats = [] stats = []
...@@ -109,12 +109,12 @@ def convert_sharegpt_to_openai( ...@@ -109,12 +109,12 @@ def convert_sharegpt_to_openai(
seed: int, seed: int,
input_file: str, input_file: str,
output_file: str, output_file: str,
max_items: Optional[int], max_items: int | None,
min_content_len: Optional[int] = None, min_content_len: int | None = None,
max_content_len: Optional[int] = None, max_content_len: int | None = None,
min_turns: Optional[int] = None, min_turns: int | None = None,
max_turns: Optional[int] = None, max_turns: int | None = None,
model: Optional[str] = None, model: str | None = None,
) -> None: ) -> None:
if min_turns and max_turns: if min_turns and max_turns:
assert min_turns <= max_turns assert min_turns <= max_turns
......
...@@ -2,4 +2,5 @@ numpy>=1.24 ...@@ -2,4 +2,5 @@ numpy>=1.24
pandas>=2.0.0 pandas>=2.0.0
aiohttp>=3.10 aiohttp>=3.10
transformers>=4.46 transformers>=4.46
xlsxwriter>=3.2.1 xlsxwriter>=3.2.1
\ No newline at end of file tqdm>=4.66
...@@ -5,7 +5,7 @@ import cProfile ...@@ -5,7 +5,7 @@ import cProfile
import pstats import pstats
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k. # A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000 LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
......
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
# Maximum line length used by ruff in this directory.
line-length = 88
# Per-file exemptions: each key is a path glob, each value is the list of rule
# codes ignored for matching files ("ALL" ignores every rule).
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
# Rule families that are enabled.
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
# Individual rules that stay disabled even though their family is selected.
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
# Sort `vllm` imports into the first-party group.
known-first-party = ["vllm"]
[tool.ruff.format]
# Also format code examples that appear inside docstrings.
docstring-code-format = true
\ No newline at end of file
...@@ -15,6 +15,7 @@ endif() ...@@ -15,6 +15,7 @@ endif()
# #
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(ENABLE_AVX512VNNI OFF) set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif() endif()
# Enable AMX-BF16 when the local CPU reports the `amx_bf16` flag, or when it
# is requested explicitly through ENABLE_AMXBF16 (read from the
# VLLM_CPU_AMXBF16 environment variable, e.g. for cross-compilation).
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
  # The AMX compile flags are only enabled for GCC 12.3 or newer.
  if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
      CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
    list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
    set(ENABLE_AMXBF16 ON)
    # Pass the bare macro name: add_compile_definitions() expects VAR or
    # VAR=value, and a leading "-D" is only stripped by CMake 3.26 and newer.
    add_compile_definitions(CPU_CAPABILITY_AMXBF16)
  else()
    set(ENABLE_AMXBF16 OFF)
    message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
  endif()
else()
  set(ENABLE_AMXBF16 OFF)
  message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND) elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2") list(APPEND CXX_COMPILE_FLAGS "-mavx2")
...@@ -188,31 +205,115 @@ else() ...@@ -188,31 +205,115 @@ else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif() endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
FetchContent_Declare( # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
oneDNN # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
GIT_TAG v3.9 if(ASIMD_FOUND)
GIT_PROGRESS TRUE # Set number of parallel build processes
GIT_SHALLOW TRUE include(ProcessorCount)
) ProcessorCount(NPROC)
if(NOT NPROC)
set(NPROC 4)
endif()
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
# and create a local shim dir with it
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
find_library(OPEN_MP
NAMES gomp
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
NO_DEFAULT_PATH
REQUIRED
)
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
if (OPEN_MP)
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
endif()
# Fetch and populate ACL
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}")
message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}")
else()
message(STATUS "Downloading Arm Compute Library (ACL) from GitHub")
FetchContent_Populate(arm_compute
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild"
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src"
GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git
GIT_TAG v52.6.0
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE
)
set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}")
set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build")
endif()
if(USE_ACL) # Build ACL with CMake
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/) set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
if(NOT ARM_COMPUTE_LIBRARY) set(CMAKE_BUILD_TYPE "Release")
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR") set(ARM_COMPUTE_ARCH "armv8.2-a")
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
set(ARM_COMPUTE_BUILD_TESTING "OFF")
set(_cmake_config_cmd
${CMAKE_COMMAND} -G Ninja -B build
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
-DCMAKE_BUILD_TYPE=Release
-DARM_COMPUTE_ARCH=armv8.2-a
-DARM_COMPUTE_ENABLE_ASSERTS=OFF
-DARM_COMPUTE_ENABLE_CPPTHREADS=OFF
-DARM_COMPUTE_ENABLE_OPENMP=ON
-DARM_COMPUTE_ENABLE_WERROR=OFF
-DARM_COMPUTE_BUILD_EXAMPLES=OFF
-DARM_COMPUTE_BUILD_TESTING=OFF)
set(_cmake_build_cmd
${CMAKE_COMMAND} --build build -- -j${NPROC}
)
execute_process(
COMMAND ${_cmake_config_cmd}
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
)
execute_process(
COMMAND ${_cmake_build_cmd}
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
RESULT_VARIABLE _acl_rc
)
if(NOT _acl_rc EQUAL 0)
message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).")
endif() endif()
set(ONEDNN_AARCH64_USE_ACL "ON") message(STATUS "Arm Compute Library (ACL) built successfully.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
# VLLM/oneDNN settings for ACL
set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE)
add_compile_definitions(VLLM_USE_ACL)
endif()
set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.")
if(FETCHCONTENT_SOURCE_DIR_ONEDNN)
message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}")
FetchContent_Declare(
oneDNN
SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN}
)
else()
message(STATUS "Downloading oneDNN from GitHub")
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
endif() endif()
set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_LIBRARY_TYPE "STATIC")
...@@ -229,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -229,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
target_include_directories( target_include_directories(
dnnl_ext dnnl_ext
...@@ -259,18 +363,19 @@ endif() ...@@ -259,18 +363,19 @@ endif()
# #
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp" "csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp" "csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp" "csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp" "csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp") "csrc/cpu/cpu_attn.cpp"
"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
...@@ -297,7 +402,7 @@ message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") ...@@ -297,7 +402,7 @@ message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
# Define extension targets # Define extension targets
# #
define_gpu_extension_target( define_extension_target(
_C _C
DESTINATION vllm DESTINATION vllm
LANGUAGE CXX LANGUAGE CXX
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment