Commit 006693ed authored by zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501 # ruff: noqa: E501
import time import time
...@@ -9,7 +8,7 @@ import torch ...@@ -9,7 +8,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
w8a8_block_fp8_matmul, w8a8_triton_block_scaled_mm,
) )
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
...@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import ( ...@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
) )
def benchmark_shape(m: int, def benchmark_shape(
n: int, m: int,
k: int, n: int,
warmup: int = 100, k: int,
repeat: int = 10000, warmup: int = 100,
verbose: bool = False) -> dict: repeat: int = 10000,
verbose: bool = False,
) -> dict:
"""Benchmark all implementations for a specific (m, n, k) shape.""" """Benchmark all implementations for a specific (m, n, k) shape."""
if verbose: if verbose:
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
# Create test tensors # Create test tensors
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16 # Reference result in BF16
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -49,34 +50,39 @@ def benchmark_shape(m: int, ...@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
# Pre-quantize A for all implementations # Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True) A, block_size[1], column_major_scales=True
)
# === DeepGEMM Implementation === # === DeepGEMM Implementation ===
def deepgemm_gemm(): def deepgemm_gemm():
fp8_gemm_nt((A_deepgemm, A_scale_deepgemm), fp8_gemm_nt(
(B_deepgemm, B_scale_deepgemm), (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
C_deepgemm) )
return C_deepgemm return C_deepgemm
# === vLLM Triton Implementation === # === vLLM Triton Implementation ===
def vllm_triton_gemm(): def vllm_triton_gemm():
return w8a8_block_fp8_matmul(A_vllm, return w8a8_triton_block_scaled_mm(
B_vllm, A_vllm,
A_scale_vllm, B_vllm,
B_scale_vllm, A_scale_vllm,
block_size, B_scale_vllm,
output_dtype=torch.bfloat16) block_size,
output_dtype=torch.bfloat16,
)
# === vLLM CUTLASS Implementation === # === vLLM CUTLASS Implementation ===
def vllm_cutlass_gemm(): def vllm_cutlass_gemm():
return ops.cutlass_scaled_mm(A_vllm_cutlass, return ops.cutlass_scaled_mm(
B_vllm.T, A_vllm_cutlass,
scale_a=A_scale_vllm_cutlass, B_vllm.T,
scale_b=B_scale_vllm.T, scale_a=A_scale_vllm_cutlass,
out_dtype=torch.bfloat16) scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16,
)
# Run correctness check first # Run correctness check first
if verbose: if verbose:
...@@ -93,26 +99,23 @@ def benchmark_shape(m: int, ...@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
print("vLLM Triton vs DeepGEMM difference: " print(
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") "vLLM Triton vs DeepGEMM difference: "
print("vLLM CUTLASS vs DeepGEMM difference: " f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") )
print(
"vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
)
# Benchmark implementations # Benchmark implementations
implementations = { implementations = {
"DeepGEMM": deepgemm_gemm, "DeepGEMM": deepgemm_gemm,
"vLLM Triton": vllm_triton_gemm, "vLLM Triton": vllm_triton_gemm,
"vLLM CUTLASS": vllm_cutlass_gemm "vLLM CUTLASS": vllm_cutlass_gemm,
} }
benchmark_results = { benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
"shape": {
"m": m,
"n": n,
"k": k
},
"implementations": {}
}
for name, func in implementations.items(): for name, func in implementations.items():
# Warmup # Warmup
...@@ -140,38 +143,36 @@ def benchmark_shape(m: int, ...@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
"tflops": tflops, "tflops": tflops,
"gb_s": gb_s, "gb_s": gb_s,
"diff": { "diff": {
"DeepGEMM": "DeepGEMM": 0.0
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), if name == "DeepGEMM"
"Reference": else calc_diff(func(), C_deepgemm),
deepgemm_diff if name == "DeepGEMM" else "Reference": deepgemm_diff
(vllm_triton_diff if name == "DeepGEMM"
if name == "vLLM Triton" else vllm_cutlass_diff) else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
} },
} }
if verbose: if verbose:
print( print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
)
# Calculate speedups # Calculate speedups
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
for name, data in benchmark_results["implementations"].items(): for name, data in benchmark_results["implementations"].items():
if name != "DeepGEMM": if name != "DeepGEMM":
speedup = baseline / data["time_ms"] speedup = baseline / data["time_ms"]
benchmark_results["implementations"][name][ benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
"speedup_vs_deepgemm"] = speedup
if verbose: if verbose:
print(f"DeepGEMM is {1/speedup:.2f}x " print(
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") f"DeepGEMM is {1 / speedup:.2f}x "
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
)
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
"time_ms"] vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
"time_ms"]
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
benchmark_results["implementations"]["vLLM CUTLASS"][ benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
"speedup_vs_triton"] = cutlass_vs_triton cutlass_vs_triton
)
if verbose: if verbose:
print( print(
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
...@@ -183,8 +184,7 @@ def benchmark_shape(m: int, ...@@ -183,8 +184,7 @@ def benchmark_shape(m: int,
def format_table_row(values, widths): def format_table_row(values, widths):
"""Format a row with specified column widths.""" """Format a row with specified column widths."""
return "| " + " | ".join(f"{val:{w}}" return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
for val, w in zip(values, widths)) + " |"
def print_table(headers, rows, title=None): def print_table(headers, rows, title=None):
...@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False): ...@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
for result in all_results: for result in all_results:
shape = result["shape"] shape = result["shape"]
impl_data = result["implementations"]["DeepGEMM"] impl_data = result["implementations"]["DeepGEMM"]
deepgemm_rows.append([ deepgemm_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" shape["m"],
]) shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
]
)
print_table(deepgemm_headers, print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
deepgemm_rows,
title="DeepGEMM Implementation:")
# Print vLLM Triton table # Print vLLM Triton table
triton_headers = [ triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
]
triton_rows = [] triton_rows = []
for result in all_results: for result in all_results:
shape = result["shape"] shape = result["shape"]
impl_data = result["implementations"]["vLLM Triton"] impl_data = result["implementations"]["vLLM Triton"]
speedup = impl_data.get("speedup_vs_deepgemm", 1.0) speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
triton_rows.append([ triton_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", shape["m"],
format_speedup(speedup) shape["n"],
]) shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(speedup),
]
)
print_table(triton_headers, print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
triton_rows,
title="vLLM Triton Implementation:")
# Print vLLM CUTLASS table # Print vLLM CUTLASS table
cutlass_headers = [ cutlass_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", "m",
"vs Triton" "n",
"k",
"Time (μs)",
"TFLOPS",
"GB/s",
"vs DeepGEMM",
"vs Triton",
] ]
cutlass_rows = [] cutlass_rows = []
for result in all_results: for result in all_results:
...@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False): ...@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
impl_data = result["implementations"]["vLLM CUTLASS"] impl_data = result["implementations"]["vLLM CUTLASS"]
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
vs_triton = impl_data.get("speedup_vs_triton", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0)
cutlass_rows.append([ cutlass_rows.append(
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", [
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", shape["m"],
format_speedup(vs_deepgemm), shape["n"],
format_speedup(vs_triton) shape["k"],
]) f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton),
]
)
print_table(cutlass_headers, print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
cutlass_rows,
title="vLLM CUTLASS Implementation:")
# Calculate and print averages # Calculate and print averages
print("\n===== AVERAGE PERFORMANCE =====") print("\n===== AVERAGE PERFORMANCE =====")
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
avg_metrics = { avg_metrics = {
impl: { impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
"tflops": 0,
"gb_s": 0,
"time_ms": 0
}
for impl in implementations
} }
for result in all_results: for result in all_results:
...@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False): ...@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
avg_time = avg_metrics[impl]["time_ms"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes
avg_rows.append([ avg_rows.append(
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
]) )
print_table(avg_headers, avg_rows) print_table(avg_headers, avg_rows)
...@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False): ...@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
avg_speedups = { avg_speedups = {
"DeepGEMM vs vLLM Triton": 0, "DeepGEMM vs vLLM Triton": 0,
"DeepGEMM vs vLLM CUTLASS": 0, "DeepGEMM vs vLLM CUTLASS": 0,
"vLLM CUTLASS vs vLLM Triton": 0 "vLLM CUTLASS vs vLLM Triton": 0,
} }
for result in all_results: for result in all_results:
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
"time_ms"]
avg_speedups[ avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups[ avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time vllm_triton_time / vllm_cutlass_time
avg_speedups[ )
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
print("\n===== AVERAGE SPEEDUPS =====") print("\n===== AVERAGE SPEEDUPS =====")
speedup_headers = ["Comparison", "Speedup"] speedup_headers = ["Comparison", "Speedup"]
...@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False): ...@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
for result in all_results: for result in all_results:
for impl in implementations: for impl in implementations:
avg_diff[impl] += result["implementations"][impl]["diff"][ avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
"Reference"]
diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"]
diff_rows = [] diff_rows = []
......
...@@ -11,7 +11,7 @@ import regex as re ...@@ -11,7 +11,7 @@ import regex as re
import seaborn as sns import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from collections.abc import Iterable from collections.abc import Callable, Iterable
from typing import Any, Callable, Optional from typing import Any
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
...@@ -55,7 +55,7 @@ class Bench: ...@@ -55,7 +55,7 @@ class Bench:
def __init__( def __init__(
self, self,
cuda_graph_params: Optional[CudaGraphBenchParams], cuda_graph_params: CudaGraphBenchParams | None,
label: str, label: str,
sub_label: str, sub_label: str,
description: str, description: str,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from statistics import mean from statistics import mean
from typing import Any, NamedTuple, Optional, Union from typing import Any, NamedTuple
import numpy as np # type: ignore import numpy as np # type: ignore
import pandas as pd # type: ignore import pandas as pd # type: ignore
...@@ -11,6 +11,7 @@ from bench_utils import ( ...@@ -11,6 +11,7 @@ from bench_utils import (
Color, Color,
logger, logger,
) )
from tqdm import tqdm
from transformers import AutoTokenizer # type: ignore from transformers import AutoTokenizer # type: ignore
# Conversation ID is a string (e.g: "UzTK34D") # Conversation ID is a string (e.g: "UzTK34D")
...@@ -35,8 +36,8 @@ class Distribution(ABC): ...@@ -35,8 +36,8 @@ class Distribution(ABC):
class UniformDistribution(Distribution): class UniformDistribution(Distribution):
def __init__( def __init__(
self, self,
min_val: Union[int, float], min_val: int | float,
max_val: Union[int, float], max_val: int | float,
is_integer: bool = True, is_integer: bool = True,
) -> None: ) -> None:
self.min_val = min_val self.min_val = min_val
...@@ -56,7 +57,7 @@ class UniformDistribution(Distribution): ...@@ -56,7 +57,7 @@ class UniformDistribution(Distribution):
class ConstantDistribution(Distribution): class ConstantDistribution(Distribution):
def __init__(self, value: Union[int, float]) -> None: def __init__(self, value: int | float) -> None:
self.value = value self.value = value
self.max_val = value self.max_val = value
...@@ -68,7 +69,7 @@ class ConstantDistribution(Distribution): ...@@ -68,7 +69,7 @@ class ConstantDistribution(Distribution):
class ZipfDistribution(Distribution): class ZipfDistribution(Distribution):
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: def __init__(self, alpha: float, max_val: int | None = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
...@@ -83,7 +84,7 @@ class ZipfDistribution(Distribution): ...@@ -83,7 +84,7 @@ class ZipfDistribution(Distribution):
class PoissonDistribution(Distribution): class PoissonDistribution(Distribution):
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: def __init__(self, alpha: float, max_val: int | None = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
...@@ -100,11 +101,11 @@ class PoissonDistribution(Distribution): ...@@ -100,11 +101,11 @@ class PoissonDistribution(Distribution):
class LognormalDistribution(Distribution): class LognormalDistribution(Distribution):
def __init__( def __init__(
self, self,
mean: Optional[float] = None, mean: float | None = None,
sigma: Optional[float] = None, sigma: float | None = None,
average: Optional[int] = None, average: int | None = None,
median_ratio: Optional[float] = None, median_ratio: float | None = None,
max_val: Optional[int] = None, max_val: int | None = None,
) -> None: ) -> None:
self.average = average self.average = average
self.median_ratio = median_ratio self.median_ratio = median_ratio
...@@ -417,6 +418,10 @@ def generate_conversations( ...@@ -417,6 +418,10 @@ def generate_conversations(
data = file.read() data = file.read()
tokens_in_file = tokenizer.encode(data, add_special_tokens=False) tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
list_of_tokens.extend(tokens_in_file) list_of_tokens.extend(tokens_in_file)
logger.info(
f"Loaded {len(tokens_in_file)} tokens from file {filename}, "
f"total tokens so far: {len(list_of_tokens)}"
)
conversations: ConversationsMap = {} conversations: ConversationsMap = {}
conv_id = 0 conv_id = 0
...@@ -449,18 +454,25 @@ def generate_conversations( ...@@ -449,18 +454,25 @@ def generate_conversations(
) )
base_offset += common_prefix_tokens base_offset += common_prefix_tokens
for conv_id in range(args.num_conversations): for conv_id in tqdm(
range(args.num_conversations),
total=args.num_conversations,
desc="Generating conversations",
unit="conv",
):
# Generate a single conversation # Generate a single conversation
messages: MessagesList = [] messages: MessagesList = []
nturns = turn_count[conv_id] nturns = turn_count[conv_id]
# User prompt token count per turn (with lower limit) # User prompt token count per turn (with lower limit)
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int)
input_token_count = np.maximum(input_token_count, base_prompt_token_count) input_token_count = np.maximum(input_token_count, base_prompt_token_count)
# Assistant answer token count per turn (with lower limit) # Assistant answer token count per turn (with lower limit)
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype(
int
)
output_token_count = np.maximum(output_token_count, 1) output_token_count = np.maximum(output_token_count, 1)
user_turn = True user_turn = True
......
...@@ -13,7 +13,7 @@ from datetime import datetime ...@@ -13,7 +13,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from statistics import mean from statistics import mean
from typing import NamedTuple, Optional, Union from typing import NamedTuple
import aiohttp # type: ignore import aiohttp # type: ignore
import numpy as np # type: ignore import numpy as np # type: ignore
...@@ -46,15 +46,16 @@ class ConversationSampling(str, Enum): ...@@ -46,15 +46,16 @@ class ConversationSampling(str, Enum):
class ClientArgs(NamedTuple): class ClientArgs(NamedTuple):
seed: int seed: int
max_num_requests: Optional[int] max_num_requests: int | None
skip_first_turn: bool skip_first_turn: bool
max_turns: Optional[int] max_turns: int | None
max_active_conversations: int max_active_conversations: int
verbose: bool verbose: bool
print_content: bool print_content: bool
verify_output: bool verify_output: bool
conversation_sampling: ConversationSampling conversation_sampling: ConversationSampling
request_rate: float request_rate: float
max_retries: int
class RequestArgs(NamedTuple): class RequestArgs(NamedTuple):
...@@ -63,6 +64,7 @@ class RequestArgs(NamedTuple): ...@@ -63,6 +64,7 @@ class RequestArgs(NamedTuple):
stream: bool stream: bool
limit_min_tokens: int # Use negative value for no limit limit_min_tokens: int # Use negative value for no limit
limit_max_tokens: int # Use negative value for no limit limit_max_tokens: int # Use negative value for no limit
timeout_sec: int
class BenchmarkArgs(NamedTuple): class BenchmarkArgs(NamedTuple):
...@@ -109,9 +111,9 @@ class RequestStats(NamedTuple): ...@@ -109,9 +111,9 @@ class RequestStats(NamedTuple):
class MetricStats: class MetricStats:
def __init__(self) -> None: def __init__(self) -> None:
self.min: Optional[float] = None self.min: float | None = None
self.max: Optional[float] = None self.max: float | None = None
self.avg: Optional[float] = None self.avg: float | None = None
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
...@@ -143,7 +145,7 @@ class MovingAverage: ...@@ -143,7 +145,7 @@ class MovingAverage:
self.index = 0 self.index = 0
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
self.avg: Optional[float] = None self.avg: float | None = None
def update(self, new_value: float) -> None: def update(self, new_value: float) -> None:
if self.count < self.window_size: if self.count < self.window_size:
...@@ -169,7 +171,7 @@ class MovingAverage: ...@@ -169,7 +171,7 @@ class MovingAverage:
class DebugStats: class DebugStats:
def __init__(self, logger: logging.Logger, window_size: int) -> None: def __init__(self, logger: logging.Logger, window_size: int) -> None:
self.logger = logger self.logger = logger
self.metrics: dict[str, Union[MovingAverage, MetricStats]] = { self.metrics: dict[str, MovingAverage | MetricStats] = {
"moving_avg_ttft_ms": MovingAverage(window_size), "moving_avg_ttft_ms": MovingAverage(window_size),
"moving_avg_tpot_ms": MovingAverage(window_size), "moving_avg_tpot_ms": MovingAverage(window_size),
"ttft_ms": MetricStats(), "ttft_ms": MetricStats(),
...@@ -198,14 +200,6 @@ class DebugStats: ...@@ -198,14 +200,6 @@ class DebugStats:
self.logger.info("-" * 50) self.logger.info("-" * 50)
# Must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped, if present.

    When ``text`` does not start with ``prefix`` it is returned unchanged.
    """
    return text[len(prefix) :] if text.startswith(prefix) else text
def nanosec_to_millisec(value: float) -> float: def nanosec_to_millisec(value: float) -> float:
return value / 1000000.0 return value / 1000000.0
...@@ -220,8 +214,9 @@ async def send_request( ...@@ -220,8 +214,9 @@ async def send_request(
chat_url: str, chat_url: str,
model: str, model: str,
stream: bool = True, stream: bool = True,
min_tokens: Optional[int] = None, min_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
timeout_sec: int = 120,
) -> ServerResponse: ) -> ServerResponse:
payload = { payload = {
"model": model, "model": model,
...@@ -243,16 +238,22 @@ async def send_request( ...@@ -243,16 +238,22 @@ async def send_request(
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
# Calculate the timeout for the request # Calculate the timeout for the request
timeout_sec = 120
if max_tokens is not None: if max_tokens is not None:
# Assume TPOT of 200ms and use max_tokens to determine timeout # Assume TPOT of 200ms and use max_tokens to determine timeout
timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) token_based_timeout = int(max_tokens * 0.2)
if token_based_timeout > timeout_sec:
timeout_sec = token_based_timeout
logger.info(
"Using timeout of %ds based on max_tokens %d",
timeout_sec,
max_tokens,
)
timeout = aiohttp.ClientTimeout(total=timeout_sec) timeout = aiohttp.ClientTimeout(total=timeout_sec)
valid_response = True valid_response = True
ttft: Optional[float] = None ttft: float | None = None
chunk_delay: list[int] = [] chunk_delay: list[int] = []
latency: Optional[float] = None latency: float | None = None
first_chunk = "" first_chunk = ""
generated_text = "" generated_text = ""
...@@ -269,7 +270,7 @@ async def send_request( ...@@ -269,7 +270,7 @@ async def send_request(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
# End of stream # End of stream
latency = time.perf_counter_ns() - start_time latency = time.perf_counter_ns() - start_time
...@@ -364,7 +365,7 @@ async def send_turn( ...@@ -364,7 +365,7 @@ async def send_turn(
req_args: RequestArgs, req_args: RequestArgs,
verbose: bool, verbose: bool,
verify_output: bool, verify_output: bool,
) -> Optional[RequestStats]: ) -> RequestStats | None:
assert messages_to_use > 0 assert messages_to_use > 0
assert messages_to_use <= len(conversation_messages) assert messages_to_use <= len(conversation_messages)
...@@ -417,6 +418,7 @@ async def send_turn( ...@@ -417,6 +418,7 @@ async def send_turn(
req_args.stream, req_args.stream,
min_tokens, min_tokens,
max_tokens, max_tokens,
req_args.timeout_sec,
) )
if response.valid is False: if response.valid is False:
...@@ -526,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: ...@@ -526,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
await asyncio.sleep(interval) await asyncio.sleep(interval)
async def exponential_backoff_sleep(
    attempt_cnt: int,
    base_rate: float = 1.0,
    backoff_factor: float = 2.0,
    jitter_fraction: float = 0.10,
    verbose: bool = False,
) -> None:
    """Sleep with exponential backoff and jitter after a failed request.

    The nominal delay is ``base_rate * backoff_factor**attempt_cnt`` seconds.
    It is then scaled by ``1 + u``, where ``u`` is drawn uniformly from
    ``[-jitter_fraction, jitter_fraction)`` using NumPy's global RNG.

    Args:
        attempt_cnt: Exponent applied to ``backoff_factor``.
        base_rate: Delay multiplier, in seconds.
        backoff_factor: Base of the exponential growth.
        jitter_fraction: Half-width of the relative random jitter.
        verbose: When true, log the chosen delay before sleeping.
    """
    nominal_delay = base_rate * (backoff_factor**attempt_cnt)

    # Exactly one draw from the global NumPy RNG per call.
    jitter = np.random.uniform(-jitter_fraction, jitter_fraction)
    sleep_for = nominal_delay * (1 + jitter)

    if verbose:
        logger.info(f"Backoff for {sleep_for:.3f} seconds...")

    await asyncio.sleep(sleep_for)
async def client_main( async def client_main(
args: ClientArgs, args: ClientArgs,
req_args: RequestArgs, req_args: RequestArgs,
...@@ -540,8 +561,11 @@ async def client_main( ...@@ -540,8 +561,11 @@ async def client_main(
f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
) )
random.seed(args.seed) # Set unique seed per client (each client runs in its own process)
np.random.seed(args.seed) # Add 1 to ensure no client uses the same seed as the main process
client_seed = args.seed + client_id + 1
random.seed(client_seed)
np.random.seed(client_seed)
# Active conversations # Active conversations
active_convs: ConversationsMap = {} active_convs: ConversationsMap = {}
...@@ -644,7 +668,7 @@ async def client_main( ...@@ -644,7 +668,7 @@ async def client_main(
if args.verbose: if args.verbose:
curr_time_sec: float = time.perf_counter() curr_time_sec: float = time.perf_counter()
time_since_last_turn: Union[str, float] = "N/A" time_since_last_turn: str | float = "N/A"
if conv_id in time_of_last_turn: if conv_id in time_of_last_turn:
time_since_last_turn = round( time_since_last_turn = round(
curr_time_sec - time_of_last_turn[conv_id], 3 curr_time_sec - time_of_last_turn[conv_id], 3
...@@ -654,49 +678,62 @@ async def client_main( ...@@ -654,49 +678,62 @@ async def client_main(
) )
time_of_last_turn[conv_id] = curr_time_sec time_of_last_turn[conv_id] = curr_time_sec
success = True success = False
try: for attempt_cnt in range(args.max_retries + 1):
result = await send_turn( try:
session, exception = False
client_id, result = await send_turn(
conv_id, session,
messages, client_id,
current_turn, conv_id,
tokenizer, messages,
req_args, current_turn,
args.print_content, tokenizer,
args.verify_output, req_args,
) args.print_content,
if result is not None: args.verify_output,
result_queue.put(result) )
else: if result is not None:
# None means that the request failed, result_queue.put(result)
# and should not be added to the statistics. success = True
success = False break
num_failures += 1 else:
logger.warning(
logger.warning( f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 )
except asyncio.exceptions.TimeoutError:
exception = True
logger.error(
"%sClient %d - Timeout during conversation ID %s (turn: %d). "
"Base timeout is %ss (set with --request-timeout-sec), but the "
"effective timeout may be longer based on max_tokens. If this "
"is unexpected, consider increasing the timeout or checking "
"model performance.%s",
Color.RED,
client_id,
conv_id,
current_turn,
req_args.timeout_sec,
Color.RESET,
)
except Exception:
exception = True
logger.exception(
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
) )
# Remove the conversation (should not be used again) # Sleep before retry if not last attempt
active_convs.pop(conv_id) if not success and attempt_cnt < args.max_retries:
await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose)
except asyncio.exceptions.TimeoutError:
num_failures += 1
logger.exception(
f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
)
break # Exit gracefully instead of raising an error
except Exception: if not success:
num_failures += 1 num_failures += 1
logger.exception( # Remove the conversation (should not be used again)
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 active_convs.pop(conv_id)
) if exception:
break # Exit gracefully instead of raising an error break # Exit gracefully instead of raising an error
if success: else:
num_successes += 1 num_successes += 1
# Update the turns counter to include the LLM response # Update the turns counter to include the LLM response
...@@ -769,7 +806,7 @@ def get_client_config( ...@@ -769,7 +806,7 @@ def get_client_config(
"Number of conversations must be equal or larger than the number of clients" "Number of conversations must be equal or larger than the number of clients"
) )
max_req_per_client: Optional[int] = None max_req_per_client: int | None = None
if args.max_num_requests is not None: if args.max_num_requests is not None:
# Max number of requests per client # Max number of requests per client
req_per_client = args.max_num_requests // args.num_clients req_per_client = args.max_num_requests // args.num_clients
...@@ -811,6 +848,7 @@ def get_client_config( ...@@ -811,6 +848,7 @@ def get_client_config(
verify_output=args.verify_output, verify_output=args.verify_output,
conversation_sampling=args.conversation_sampling, conversation_sampling=args.conversation_sampling,
request_rate=args.request_rate, request_rate=args.request_rate,
max_retries=args.max_retries,
) )
if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
...@@ -823,6 +861,9 @@ def get_client_config( ...@@ -823,6 +861,9 @@ def get_client_config(
"Invalid min/max tokens limits (min should not be larger than max)" "Invalid min/max tokens limits (min should not be larger than max)"
) )
if args.request_timeout_sec <= 0:
raise ValueError("Request timeout must be a positive number")
# Arguments for API requests # Arguments for API requests
chat_url = f"{args.url}/v1/chat/completions" chat_url = f"{args.url}/v1/chat/completions"
model_name = args.served_model_name if args.served_model_name else args.model model_name = args.served_model_name if args.served_model_name else args.model
...@@ -833,6 +874,7 @@ def get_client_config( ...@@ -833,6 +874,7 @@ def get_client_config(
stream=not args.no_stream, stream=not args.no_stream,
limit_min_tokens=args.limit_min_tokens, limit_min_tokens=args.limit_min_tokens,
limit_max_tokens=args.limit_max_tokens, limit_max_tokens=args.limit_max_tokens,
timeout_sec=args.request_timeout_sec,
) )
return client_args, req_args return client_args, req_args
...@@ -936,13 +978,13 @@ async def main_mp( ...@@ -936,13 +978,13 @@ async def main_mp(
f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501
) )
rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3) rps: str | float = round(len(client_metrics) / runtime_sec, 3)
if len(client_metrics) < (5 * bench_args.num_clients): if len(client_metrics) < (5 * bench_args.num_clients):
# Do not estimate the RPS if the number of samples is very low # Do not estimate the RPS if the number of samples is very low
# (threshold can be tuned if needed) # (threshold can be tuned if needed)
rps = "N/A" rps = "N/A"
runtime_left_sec: Union[str, float] = round( runtime_left_sec: str | float = round(
(runtime_sec / finished_convs) * (total_convs - finished_convs), 3 (runtime_sec / finished_convs) * (total_convs - finished_convs), 3
) )
if percent < 0.05: if percent < 0.05:
...@@ -976,7 +1018,7 @@ async def main_mp( ...@@ -976,7 +1018,7 @@ async def main_mp(
f"(is alive: {client.is_alive()}){Color.RESET}" f"(is alive: {client.is_alive()}){Color.RESET}"
) )
client.join(timeout=120) client.join(timeout=req_args.timeout_sec + 1)
if client.is_alive(): if client.is_alive():
logger.warning( logger.warning(
...@@ -1032,7 +1074,7 @@ def process_statistics( ...@@ -1032,7 +1074,7 @@ def process_statistics(
warmup_percentages: list[float], warmup_percentages: list[float],
test_params: dict, test_params: dict,
verbose: bool, verbose: bool,
gen_conv_args: Optional[GenConvArgs] = None, gen_conv_args: GenConvArgs | None = None,
excel_output: bool = False, excel_output: bool = False,
) -> None: ) -> None:
if len(client_metrics) == 0: if len(client_metrics) == 0:
...@@ -1259,7 +1301,7 @@ async def main() -> None: ...@@ -1259,7 +1301,7 @@ async def main() -> None:
default=None, default=None,
help="The model name used in the API. " help="The model name used in the API. "
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the ``--model`` argument. ", "same as the `--model` argument. ",
) )
parser.add_argument( parser.add_argument(
...@@ -1342,6 +1384,16 @@ async def main() -> None: ...@@ -1342,6 +1384,16 @@ async def main() -> None:
help="Expected request rate (Poisson process) per client in requests/sec." help="Expected request rate (Poisson process) per client in requests/sec."
"Set to 0 for no delay between requests.", "Set to 0 for no delay between requests.",
) )
parser.add_argument(
"--max-retries",
type=int,
default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")),
help="Maximum number of retry attempts for timed-out requests. "
"Default is 0 (no retries). "
"Set to higher values to retry failed requests and maintain "
"fair workload distribution. "
"Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.",
)
parser.add_argument( parser.add_argument(
"--conversation-sampling", "--conversation-sampling",
type=ConversationSampling, type=ConversationSampling,
...@@ -1359,6 +1411,13 @@ async def main() -> None: ...@@ -1359,6 +1411,13 @@ async def main() -> None:
action="store_true", action="store_true",
help="Verify the LLM output (compare to the answers in the input JSON file)", help="Verify the LLM output (compare to the answers in the input JSON file)",
) )
parser.add_argument(
"--request-timeout-sec",
type=int,
default=120,
help="Timeout in seconds for each API request (default: 120). "
"Automatically increased if max tokens imply longer decoding.",
)
parser.add_argument( parser.add_argument(
"--no-stream", "--no-stream",
...@@ -1434,11 +1493,10 @@ async def main() -> None: ...@@ -1434,11 +1493,10 @@ async def main() -> None:
f"Invalid --warmup-percentage={args.warmup_percentage}" f"Invalid --warmup-percentage={args.warmup_percentage}"
) from None ) from None
# Set global seeds for main process
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
if not os.path.exists(args.model):
raise OSError(f"Path does not exist: {args.model}")
logger.info("Loading tokenizer") logger.info("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model)
......
...@@ -13,7 +13,7 @@ import argparse ...@@ -13,7 +13,7 @@ import argparse
import json import json
import random import random
from statistics import mean from statistics import mean
from typing import Any, Optional from typing import Any
import pandas as pd # type: ignore import pandas as pd # type: ignore
import tqdm # type: ignore import tqdm # type: ignore
...@@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool: ...@@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool:
def content_is_valid( def content_is_valid(
content: str, min_content_len: Optional[int], max_content_len: Optional[int] content: str, min_content_len: int | None, max_content_len: int | None
) -> bool: ) -> bool:
if min_content_len and len(content) < min_content_len: if min_content_len and len(content) < min_content_len:
return False return False
...@@ -37,7 +37,7 @@ def content_is_valid( ...@@ -37,7 +37,7 @@ def content_is_valid(
def print_stats( def print_stats(
conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None
) -> None: ) -> None:
# Collect statistics # Collect statistics
stats = [] stats = []
...@@ -109,12 +109,12 @@ def convert_sharegpt_to_openai( ...@@ -109,12 +109,12 @@ def convert_sharegpt_to_openai(
seed: int, seed: int,
input_file: str, input_file: str,
output_file: str, output_file: str,
max_items: Optional[int], max_items: int | None,
min_content_len: Optional[int] = None, min_content_len: int | None = None,
max_content_len: Optional[int] = None, max_content_len: int | None = None,
min_turns: Optional[int] = None, min_turns: int | None = None,
max_turns: Optional[int] = None, max_turns: int | None = None,
model: Optional[str] = None, model: str | None = None,
) -> None: ) -> None:
if min_turns and max_turns: if min_turns and max_turns:
assert min_turns <= max_turns assert min_turns <= max_turns
......
...@@ -2,4 +2,5 @@ numpy>=1.24 ...@@ -2,4 +2,5 @@ numpy>=1.24
pandas>=2.0.0 pandas>=2.0.0
aiohttp>=3.10 aiohttp>=3.10
transformers>=4.46 transformers>=4.46
xlsxwriter>=3.2.1 xlsxwriter>=3.2.1
\ No newline at end of file tqdm>=4.66
...@@ -5,7 +5,7 @@ import cProfile ...@@ -5,7 +5,7 @@ import cProfile
import pstats import pstats
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k. # A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000 LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
......
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
# Maximum line length applied by ruff for files governed by this config.
line-length = 88
[tool.ruff.lint.per-file-ignores]
# Per-path exemptions: "ALL" switches off every lint rule for matching files.
"vllm/third_party/**" = ["ALL"]
# F401 is the Pyflakes "imported but unused" rule.
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
# Rule families enabled for linting (one family per prefix below).
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
# Individual rules excluded from the families selected above.
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
# Sort `vllm` imports in the first-party group.
known-first-party = ["vllm"]
[tool.ruff.format]
# Also format code examples that appear inside docstrings.
docstring-code-format = true
\ No newline at end of file
...@@ -15,6 +15,7 @@ endif() ...@@ -15,6 +15,7 @@ endif()
# #
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ...@@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(ENABLE_AVX512VNNI OFF) set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif() endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND) elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2") list(APPEND CXX_COMPILE_FLAGS "-mavx2")
...@@ -188,31 +205,115 @@ else() ...@@ -188,31 +205,115 @@ else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif() endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
FetchContent_Declare( # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
oneDNN # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
GIT_TAG v3.9 if(ASIMD_FOUND)
GIT_PROGRESS TRUE # Set number of parallel build processes
GIT_SHALLOW TRUE include(ProcessorCount)
) ProcessorCount(NPROC)
if(NOT NPROC)
set(NPROC 4)
endif()
# locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
# and create a local shim dir with it
vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
find_library(OPEN_MP
NAMES gomp
PATHS ${VLLM_TORCH_GOMP_SHIM_DIR}
NO_DEFAULT_PATH
REQUIRED
)
# Set LD_LIBRARY_PATH to include the shim dir at build time to use the same libgomp as PyTorch
if (OPEN_MP)
set(ENV{LD_LIBRARY_PATH} "${VLLM_TORCH_GOMP_SHIM_DIR}:$ENV{LD_LIBRARY_PATH}")
endif()
# Fetch and populate ACL
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}")
message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}")
else()
message(STATUS "Downloading Arm Compute Library (ACL) from GitHub")
FetchContent_Populate(arm_compute
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild"
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src"
GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git
GIT_TAG v52.6.0
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE
)
set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}")
set(ACL_LIB_DIR "$ENV{ACL_ROOT_DIR}/build")
endif()
if(USE_ACL) # Build ACL with CMake
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/) set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
if(NOT ARM_COMPUTE_LIBRARY) set(CMAKE_BUILD_TYPE "Release")
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR") set(ARM_COMPUTE_ARCH "armv8.2-a")
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
set(ARM_COMPUTE_BUILD_TESTING "OFF")
set(_cmake_config_cmd
${CMAKE_COMMAND} -G Ninja -B build
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
-DCMAKE_BUILD_TYPE=Release
-DARM_COMPUTE_ARCH=armv8.2-a
-DARM_COMPUTE_ENABLE_ASSERTS=OFF
-DARM_COMPUTE_ENABLE_CPPTHREADS=OFF
-DARM_COMPUTE_ENABLE_OPENMP=ON
-DARM_COMPUTE_ENABLE_WERROR=OFF
-DARM_COMPUTE_BUILD_EXAMPLES=OFF
-DARM_COMPUTE_BUILD_TESTING=OFF)
set(_cmake_build_cmd
${CMAKE_COMMAND} --build build -- -j${NPROC}
)
execute_process(
COMMAND ${_cmake_config_cmd}
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
)
execute_process(
COMMAND ${_cmake_build_cmd}
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
RESULT_VARIABLE _acl_rc
)
if(NOT _acl_rc EQUAL 0)
message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).")
endif() endif()
set(ONEDNN_AARCH64_USE_ACL "ON") message(STATUS "Arm Compute Library (ACL) built successfully.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
# VLLM/oneDNN settings for ACL
set(ONEDNN_AARCH64_USE_ACL ON CACHE BOOL "" FORCE)
add_compile_definitions(VLLM_USE_ACL)
endif()
set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.")
if(FETCHCONTENT_SOURCE_DIR_ONEDNN)
message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}")
FetchContent_Declare(
oneDNN
SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN}
)
else()
message(STATUS "Downloading oneDNN from GitHub")
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
endif() endif()
set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_LIBRARY_TYPE "STATIC")
...@@ -229,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -229,7 +330,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE})
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
target_include_directories( target_include_directories(
dnnl_ext dnnl_ext
...@@ -259,14 +363,14 @@ endif() ...@@ -259,14 +363,14 @@ endif()
# #
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp" "csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp" "csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp" "csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp" "csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp" "csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp" "csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp") "csrc/cpu/cpu_attn.cpp"
"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
...@@ -297,7 +401,7 @@ message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") ...@@ -297,7 +401,7 @@ message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
# Define extension targets # Define extension targets
# #
define_gpu_extension_target( define_extension_target(
_C _C
DESTINATION vllm DESTINATION vllm
LANGUAGE CXX LANGUAGE CXX
......
...@@ -19,7 +19,7 @@ else() ...@@ -19,7 +19,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
flashmla flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f GIT_TAG 46d64a8ebef03fa50b4ae74937276a5c940e3f95
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
...@@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS) ...@@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
) )
set(FlashMLA_INCLUDES set(FlashMLA_INCLUDES
...@@ -91,7 +92,7 @@ if(FLASH_MLA_ARCHS) ...@@ -91,7 +92,7 @@ if(FLASH_MLA_ARCHS)
SRCS "${FlashMLA_Extension_SOURCES}" SRCS "${FlashMLA_Extension_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}") CUDA_ARCHS "${FLASH_MLA_ARCHS}")
define_gpu_extension_target( define_extension_target(
_flashmla_C _flashmla_C
DESTINATION vllm DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG} LANGUAGE ${VLLM_GPU_LANG}
...@@ -108,7 +109,7 @@ if(FLASH_MLA_ARCHS) ...@@ -108,7 +109,7 @@ if(FLASH_MLA_ARCHS)
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API> $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>) $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
define_gpu_extension_target( define_extension_target(
_flashmla_extension_C _flashmla_extension_C
DESTINATION vllm DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG} LANGUAGE ${VLLM_GPU_LANG}
......
# Locate the QuTLASS sources, either from a local checkout or from GitHub,
# and expose their location as qutlass_SOURCE_DIR.
include(FetchContent)
# Cache entry for the CUTLASS include/ directory; re-uses any value that is
# already set when this file is processed.
set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")
# The QUTLASS_SRC_DIR environment variable, when defined, overrides the
# CMake variable of the same name.
if(DEFINED ENV{QUTLASS_SRC_DIR})
set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
endif()
if(QUTLASS_SRC_DIR)
# Use the local source tree; nothing is downloaded.
FetchContent_Declare(
qutlass
SOURCE_DIR ${QUTLASS_SRC_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
else()
# Download QuTLASS at a pinned commit. The configure/build commands are
# empty, so no separate configure or build step is requested for it here.
FetchContent_Declare(
qutlass
GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
endif()
# Populate only (no add_subdirectory); this sets qutlass_SOURCE_DIR.
FetchContent_Populate(qutlass)
# Fail early if population did not yield a source directory.
if(NOT qutlass_SOURCE_DIR)
message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
# Keep only the architectures QuTLASS is built for (12.0a / 10.0a) out of the
# requested CUDA_ARCHS.
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")

# The compiler version is quoted so the condition stays well-formed when the
# variable is empty. VERSION_GREATER_EQUAL matches the "CUDA 12.8 or newer"
# requirement reported by the skip message in the else() branch below.
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER_EQUAL "12.8" AND QUTLASS_ARCHS)

  # Pick the single compute capability passed to the sources as
  # TARGET_CUDA_ARCH; 10.0a takes precedence when both are present.
  if(QUTLASS_ARCHS MATCHES "10\\.0a")
    set(QUTLASS_TARGET_CC 100)
  elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
    set(QUTLASS_TARGET_CC 120)
  else()
    message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
  endif()

  set(QUTLASS_SOURCES
    ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
  )

  set(QUTLASS_INCLUDES
    ${qutlass_SOURCE_DIR}
    ${qutlass_SOURCE_DIR}/qutlass
    ${qutlass_SOURCE_DIR}/qutlass/csrc/include
    ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
  )

  # Prefer the CUTLASS headers supplied via CUTLASS_INCLUDE_DIR; otherwise fall
  # back to the copy vendored inside QuTLASS, and fail if neither exists.
  if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
    list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
  elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
    list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
    message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
  else()
    message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
                        "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
  endif()

  set_gencode_flags_for_srcs(
    SRCS "${QUTLASS_SOURCES}"
    CUDA_ARCHS "${QUTLASS_ARCHS}"
  )

  # The QuTLASS sources are compiled into the existing _C target rather than
  # into a target of their own.
  target_sources(_C PRIVATE ${QUTLASS_SOURCES})
  target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
  target_compile_definitions(_C PRIVATE
    QUTLASS_DISABLE_PYBIND=1
    TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
  )

  # Extra CUDA-only compile options, applied to the QuTLASS sources only.
  set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
    $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
  )

else()
  # Report why QuTLASS is not being built.
  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
    message(STATUS
      "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
  else()
    message(STATUS
      "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
      "CUDA_ARCHS='${CUDA_ARCHS}'.")
  endif()
endif()
...@@ -38,7 +38,7 @@ else() ...@@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a GIT_TAG 58e0626a692f09241182582659e3bf8f16472659
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
...@@ -16,7 +16,7 @@ import shutil ...@@ -16,7 +16,7 @@ import shutil
from torch.utils.hipify.hipify_python import hipify from torch.utils.hipify.hipify_python import hipify
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Project directory where all the source + include files live. # Project directory where all the source + include files live.
...@@ -34,15 +34,14 @@ if __name__ == '__main__': ...@@ -34,15 +34,14 @@ if __name__ == '__main__':
) )
# Source files to convert. # Source files to convert.
parser.add_argument("sources", parser.add_argument(
help="Source files to hipify.", "sources", help="Source files to hipify.", nargs="*", default=[]
nargs="*", )
default=[])
args = parser.parse_args() args = parser.parse_args()
# Limit include scope to project_dir only # Limit include scope to project_dir only
includes = [os.path.join(args.project_dir, '*')] includes = [os.path.join(args.project_dir, "*")]
# Get absolute path for all source files. # Get absolute path for all source files.
extra_files = [os.path.abspath(s) for s in args.sources] extra_files = [os.path.abspath(s) for s in args.sources]
...@@ -51,25 +50,31 @@ if __name__ == '__main__': ...@@ -51,25 +50,31 @@ if __name__ == '__main__':
# The directory might already exist to hold object files so we ignore that. # The directory might already exist to hold object files so we ignore that.
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
hipify_result = hipify(project_directory=args.project_dir, hipify_result = hipify(
output_directory=args.output_dir, project_directory=args.project_dir,
header_include_dirs=[], output_directory=args.output_dir,
includes=includes, header_include_dirs=[],
extra_files=extra_files, includes=includes,
show_detailed=True, extra_files=extra_files,
is_pytorch_extension=True, show_detailed=True,
hipify_extra_files_only=True) is_pytorch_extension=True,
hipify_extra_files_only=True,
)
hipified_sources = [] hipified_sources = []
for source in args.sources: for source in args.sources:
s_abs = os.path.abspath(source) s_abs = os.path.abspath(source)
hipified_s_abs = (hipify_result[s_abs].hipified_path if hipified_s_abs = (
(s_abs in hipify_result hipify_result[s_abs].hipified_path
and hipify_result[s_abs].hipified_path is not None) if (
else s_abs) s_abs in hipify_result
and hipify_result[s_abs].hipified_path is not None
)
else s_abs
)
hipified_sources.append(hipified_s_abs) hipified_sources.append(hipified_s_abs)
assert (len(hipified_sources) == len(args.sources)) assert len(hipified_sources) == len(args.sources)
# Print hipified source files. # Print hipified source files.
print("\n".join(hipified_sources)) print("\n".join(hipified_sources))
...@@ -130,6 +130,44 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -130,6 +130,44 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction() endfunction()
# Find libgomp that gets shipped with PyTorch wheel and create a shim dir with:
#   libgomp.so   -> libgomp-<hash>.so...
#   libgomp.so.1 -> libgomp-<hash>.so...
# OUTPUT: TORCH_GOMP_SHIM_DIR ("" if not found)
#
# TORCH_GOMP_SHIM_DIR is the NAME of the caller's variable that receives the
# shim directory path.
function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
  # Default result: empty, meaning "no vendored libgomp found".
  set(${TORCH_GOMP_SHIM_DIR} "" PARENT_SCOPE)

  # Use run_python to locate vendored libgomp. The probe script itself never
  # raises: it prints an empty line when torch or the library cannot be found.
  # The script must start at column 0 and keep its own indentation because it
  # is passed to the interpreter verbatim.
  run_python(_VLLM_TORCH_GOMP_PATH
    "
import glob
import os
try:
    import torch
    torch_pkg = os.path.dirname(torch.__file__)
    site_root = os.path.dirname(torch_pkg)
    torch_libs = os.path.join(site_root, 'torch.libs')
    # Sort so the choice is deterministic if several files match.
    matches = sorted(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*')))
    print(matches[0] if matches else '')
except Exception:
    print('')
"
    "failed to probe torch.libs for libgomp")

  if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
    return()
  endif()

  # Create shim under the build tree
  set(_shim "${CMAKE_BINARY_DIR}/gomp_shim")
  file(MAKE_DIRECTORY "${_shim}")

  # Remove stale links first so re-configuring always points at the current
  # library path.
  execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so")
  execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so.1")

  execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so")
  execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so.1")

  set(${TORCH_GOMP_SHIM_DIR} "${_shim}" PARENT_SCOPE)
endfunction()
# Macro for converting a `gencode` version number to a cmake version number. # Macro for converting a `gencode` version number to a cmake version number.
macro(string_to_ver OUT_VER IN_STR) macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
...@@ -311,13 +349,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR ...@@ -311,13 +349,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS}) foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "\\a$") if(_arch MATCHES "[af]$")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
string(REPLACE "a" "" _base "${_arch}") string(REGEX REPLACE "[af]$" "" _base "${_arch}")
if ("${_base}" IN_LIST TGT_CUDA_ARCHS) if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
list(APPEND _CUDA_ARCHS "${_arch}") list(APPEND _CUDA_ARCHS "${_arch}")
...@@ -416,21 +454,20 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) ...@@ -416,21 +454,20 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
endmacro() endmacro()
# #
# Define a target named `GPU_MOD_NAME` for a single extension. The # Define a target named `MOD_NAME` for a single extension. The
# arguments are: # arguments are:
# #
# DESTINATION <dest> - Module destination directory. # DESTINATION <dest> - Module destination directory.
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP, # LANGUAGE <lang> - The language for this module, e.g. CUDA, HIP,
# etc. # CXX, etc.
# SOURCES <sources> - List of source files relative to CMakeLists.txt # SOURCES <sources> - List of source files relative to CMakeLists.txt
# directory. # directory.
# #
# Optional arguments: # Optional arguments:
# #
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake # ARCHITECTURES <arches> - A list of target architectures in cmake format.
# format. # For GPU, refer to CMAKE_CUDA_ARCHITECTURES and
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation # CMAKE_HIP_ARCHITECTURES for more info.
# and `CMAKE_HIP_ARCHITECTURES` for more info.
# ARCHITECTURES will use cmake's defaults if # ARCHITECTURES will use cmake's defaults if
# not provided. # not provided.
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip. # COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
...@@ -441,63 +478,61 @@ endmacro() ...@@ -441,63 +478,61 @@ endmacro()
# #
# Note: optimization level/debug info is set via cmake build type. # Note: optimization level/debug info is set via cmake build type.
# #
function (define_gpu_extension_target GPU_MOD_NAME) function (define_extension_target MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1 cmake_parse_arguments(PARSE_ARGV 1
GPU ARG
"WITH_SOABI" "WITH_SOABI"
"DESTINATION;LANGUAGE;USE_SABI" "DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm. # Add hipify preprocessing step when building with HIP/ROCm.
if (GPU_LANGUAGE STREQUAL "HIP") if (ARG_LANGUAGE STREQUAL "HIP")
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}") hipify_sources_target(ARG_SOURCES ${MOD_NAME} "${ARG_SOURCES}")
endif() endif()
if (GPU_WITH_SOABI) if (ARG_WITH_SOABI)
set(GPU_WITH_SOABI WITH_SOABI) set(SOABI_KEYWORD WITH_SOABI)
else() else()
set(GPU_WITH_SOABI) set(SOABI_KEYWORD "")
endif() endif()
if (GPU_USE_SABI) if (ARG_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}")
else() else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}")
endif() endif()
if (GPU_LANGUAGE STREQUAL "HIP") if (ARG_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step. # Make this target dependent on the hipify preprocessor step.
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) add_dependencies(${MOD_NAME} hipify${MOD_NAME})
# Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder
target_include_directories(${GPU_MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc target_include_directories(${MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc
${GPU_INCLUDE_DIRECTORIES}) ${ARG_INCLUDE_DIRECTORIES})
else() else()
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc target_include_directories(${MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES}) ${ARG_INCLUDE_DIRECTORIES})
endif() endif()
if (GPU_ARCHITECTURES) if (ARG_ARCHITECTURES)
set_target_properties(${GPU_MOD_NAME} PROPERTIES set_target_properties(${MOD_NAME} PROPERTIES
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") ${ARG_LANGUAGE}_ARCHITECTURES "${ARG_ARCHITECTURES}")
endif() endif()
target_compile_options(${MOD_NAME} PRIVATE
$<$<COMPILE_LANGUAGE:${ARG_LANGUAGE}>:${ARG_COMPILE_FLAGS}>)
target_compile_options(${GPU_MOD_NAME} PRIVATE target_compile_definitions(${MOD_NAME} PRIVATE
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>) "-DTORCH_EXTENSION_NAME=${MOD_NAME}")
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
target_link_libraries(${MOD_NAME} PRIVATE torch ${ARG_LIBRARIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
# dependencies that are not necessary and may not be installed. # dependencies that are not necessary and may not be installed.
if (GPU_LANGUAGE STREQUAL "CUDA") if (ARG_LANGUAGE STREQUAL "CUDA")
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver) target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES})
else() else()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
endif() endif()
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME})
endfunction() endfunction()
\ No newline at end of file
codecov:
require_ci_to_pass: false
fixes:
# Map source code paths to repository root paths
# Wildcards match any Python version (python3.*)
- "/vllm-workspace/src/vllm/::vllm/"
- "/vllm-workspace/vllm/::vllm/"
- "/usr/local/lib/python3.*/dist-packages/vllm/::vllm/"
- "/usr/local/lib/python3.*/site-packages/vllm/::vllm/"
- "/usr/lib/python3.*/dist-packages/vllm/::vllm/"
- "/usr/lib/python3.*/site-packages/vllm/::vllm/"
...@@ -28,10 +28,10 @@ ...@@ -28,10 +28,10 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
......
...@@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel( ...@@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel(
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse; s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const float max_lse = fmaxf(p_lse, s_lse); const float max_lse = fmaxf(p_lse, s_lse);
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that request's position is -inf, and at
that point the cross-attention merge is performed first. As a workaround
we simply emit prefix_output (expected to be all zeros) and prefix_lse
(-inf) for such entries.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
output_lse[head_idx * num_tokens + token_idx] = max_lse;
}
return;
}
p_lse = p_lse - max_lse; p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse; s_lse = s_lse - max_lse;
const float p_se = expf(p_lse); const float p_se = expf(p_lse);
......
...@@ -125,32 +125,37 @@ public: ...@@ -125,32 +125,37 @@ public:
} }
static void set_split_kv (KernelArguments& args) { static void set_split_kv (KernelArguments& args) {
// printf("set_split_kv start");
if (args.split_kv >= 1) return; if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape; auto [H, K, D, B] = args.problem_shape;
// std::cout << H << " " << K << " " << D << " " << B << "\n";
int sm_count = args.hw_info.sm_count; int sm_count = args.hw_info.sm_count;
// printf(" sm_count = %d\n", sm_count); float seq_length_k = static_cast<float>(K) / 1024.0f;
int max_splits = ceil_div(K, 128); int max_splits = 1;
max_splits = min(16, max_splits);
if (B <= 4 && seq_length_k >= 16) {
// TODO: This avoids a hang when the batch size larger than 1 and max_splits = 16;
// there is more than 1 kv_splits. }
// Discuss with NVIDIA how this can be fixed. else if (B <= 8 && seq_length_k >= 4) {
if (B > 1) { max_splits = 8;
max_splits = min(1, max_splits); }
else if ((B <= 16 && seq_length_k >= 8) ||
(B == 48 && seq_length_k >= 32)) {
max_splits = 4;
}
else if ((B <= 32 && seq_length_k >= 16) ||
(B == 96 && seq_length_k >= 16)) {
max_splits = 2;
} }
else {
// printf(" max_splits = %d\n", max_splits); max_splits = 1;
}
// Wave-aware scheduling: ensure integer number of waves in K dimension
int sms_per_batch = max(1, sm_count / B); int sms_per_batch = max(1, sm_count / B);
// printf(" sms_per_batch = %d\n", sms_per_batch);
int split_heur = min(max_splits, sms_per_batch); int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count); int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur); int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves); int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware; args.split_kv = split_wave_aware;
// printf(" args.split_kv = %d\n", args.split_kv);
} }
/// Determines whether the GEMM can execute the given problem. /// Determines whether the GEMM can execute the given problem.
......
...@@ -64,3 +64,11 @@ void indexer_k_quant_and_cache( ...@@ -64,3 +64,11 @@ void indexer_k_quant_and_cache(
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
// Gather quantized indexer K entries (and their scales) out of `kv_cache`.
// NOTE(review): only the declaration is visible here. `dst_k` and `dst_scale`
// are the only non-const tensors, so they are presumably the outputs, with
// `block_table` and `cu_seq_lens` selecting which cache entries are copied —
// confirm against the kernel definition.
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment