Commit 909abb58 authored by maxiao

adapt to sglang v0.5.2rc1 on dcu

# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
quantize_fp8_row,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
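    # Note: tp_size is accepted to mirror the CLI interface, but the shapes below
    # are not sharded by it; expert weights are benchmarked unsharded.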
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
]:
num_groups = config.n_routed_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
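    # seg_indptr holds CSR-style prefix sums of tokens per expert (segment
    # boundaries) and weight_indices maps each segment to its expert weight,
    # matching the interface of sglang's grouped_gemm_triton kernel.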
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
def create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
"""
Create test data for FP8 grouped GEMM operations.
Args:
batch_size: Total batch size
num_groups: Number of groups
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
Returns:
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
For cutlass: (x, wq, w_scale, m_sizes)
"""
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
# Create weight matrices for each group
w_list = []
for _ in range(num_groups):
w = torch.randn(
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
w_list.append(w)
# Quantize weights using quantize_fp8_row for each group
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
if backend == "triton":
# Triton format: concatenated weights
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
# Create m_sizes as int32 for triton
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
# Create and quantize input
x_fp16 = torch.randn(
batch_size, hidden_size, dtype=torch.float16, device="cuda"
)
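        # Row-wise FP8 quantization: one scale per input row (token); the weight
        # scales above are likewise one per output row of the concatenated experts.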
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
x_scale = x_scale.view(batch_size, -1)
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
elif backend == "cutlass":
# CUTLASS format: stacked weights
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Create m_sizes as int64 for cutlass
m_values = [tokens_per_group] * num_groups
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
# Create input data - separate for each group then concat
x_list = []
for _ in range(num_groups):
x = torch.randn(
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
)
x_list.append(x)
# Concatenate inputs into single tensor
x = torch.concat(x_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes
else:
raise ValueError(f"Unsupported backend: {backend}")
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
"""
Calculate memory bandwidth based on accessed expert weights.
Args:
m_sizes: Tensor containing batch sizes for each group
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
dtype: Data type of weights
Returns:
Memory size in bytes for accessed expert weights
"""
# Count non-zero groups (active experts)
if hasattr(m_sizes, "cpu"):
active_experts = torch.count_nonzero(m_sizes).item()
else:
active_experts = sum(1 for m in m_sizes if m > 0)
# Calculate bytes per element based on dtype
if dtype in [torch.float16, torch.bfloat16]:
bytes_per_element = 2
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
bytes_per_element = 1
elif dtype == torch.float32:
bytes_per_element = 4
else:
# Default to 2 bytes for unknown dtypes
bytes_per_element = 2
# Memory per expert weight matrix
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
# Total memory for active experts
total_memory_bytes = active_experts * memory_per_expert
return total_memory_bytes
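# Example (illustrative numbers): with 8 active experts, hidden_size=4096,
# intermediate_size=14336, and BF16 weights (2 bytes/element), the counted
# weight traffic is 8 * 4096 * 14336 * 2 bytes ≈ 0.94 GB per kernel call;
# activation reads/writes are intentionally excluded from this estimate.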
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": [
"fbgemm_triton_grouped_gemm_fp8",
"fbgemm_cutlass_f8f8bf16_rowwise",
"sglang_grouped_gemm",
],
"line_names": [
"FBGEMM Triton Grouped GEMM FP8",
"FBGEMM CUTLASS F8F8BF16 Rowwise",
"SGLang Grouped GEMM FP8",
],
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
"line_names": [
"FBGEMM Triton Grouped GEMM BF16",
"SGLang Grouped GEMM BF16",
],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Bandwidth (GB/s)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_triton_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="triton",
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
# Calculate memory bandwidth for BF16 operations
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.bfloat16
)
if provider == "fbgemm_triton_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
# Convert time (ms) to bandwidth (GB/s)
# Bandwidth = Memory (bytes) / Time (seconds)
# Convert ms to seconds and bytes to GB (1e9)
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
# min bandwidth = max time, max bandwidth = min time
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
return gb_per_s, min_gb_per_s, max_gb_per_s
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return 0.0, 0.0, 0.0
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()
# FlashInfer Fused AllReduce + RMSNorm Benchmark
This benchmark script is adapted from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It compares the performance of the FlashInfer fused operator used in SGLang (`trtllm_allreduce_fusion`: AllReduce + residual add + RMSNorm + optional quantization) against the conventional implementation (standard `tensor_model_parallel_all_reduce` followed by separate RMSNorm/quantization). Concretely, the script times two implementation paths: 1) standard AllReduce and RMSNorm executed as separate operations; 2) FlashInfer's fused operator, which combines AllReduce, residual add, RMSNorm, and optional quantization in a single kernel.
The results help tune the IPC workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for adopting the FP8/FP4 quantized fused operators.
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
## Feature Overview
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
- Optionally compare FP8/FP4 quantized fused paths with standard paths
- Use CUDA Graph capture and batch replay to reduce measurement noise
- Automatically select the faster standard implementation (eager vs. `torch.compile`d native version) as the baseline for speedup calculation
- Optionally export results in Markdown format
## Runtime Environment and Prerequisites
- At least 2 GPUs; the benchmark is launched as a multi-process distributed job via `torchrun` (NCCL backend)
- sglang installed/compiled properly, together with sgl-kernel and its custom operators (a quick import check is sketched below)
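As an optional pre-flight check (not part of the benchmark itself), you can confirm that the optional dependencies are importable; the script degrades gracefully when `flashinfer` or `sgl_kernel` is missing, but checking up front makes the logs easier to interpret:
```python
# Optional sanity check: verify optional dependencies before launching.
# A missing flashinfer only disables the fused paths; a missing sgl_kernel makes
# the script fall back to the native RMSNorm/quantization implementations.
import importlib

for mod in ("sglang", "sgl_kernel", "flashinfer.comm"):
    try:
        importlib.import_module(mod)
        print(f"{mod}: OK")
    except ImportError as exc:
        print(f"{mod}: not available ({exc})")
```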
## Quick Start (Command Examples)
The following examples use world_size=2; adjust `--nproc_per_node` and the other parameters to match your machine:
- Regular paths only (no quantization):
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP8 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP4 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- Larger hidden dimensions:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
```
## Parameter Description
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
- `--hidden-dim`: Hidden dimension (default: 8192)
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
- Mutually exclusive quantization options:
- `--no-quant`: No quantization testing
- `--quant-fp8`: Only FP8 quantization testing
- `--quant-fp4`: Only FP4 quantization testing
- `--quant-all`: Test all (default)
- FlashInfer related:
  - `--disable-oneshot`: Disable oneshot mode (by default both oneshot and twoshot are benchmarked)
- Runtime configuration:
  - `--warmup`: Number of warmup iterations, run both before graph capture and before graph replay (default: 5)
  - `--trials`: Number of benchmark iterations (default: 20; each captured graph batches several operations, so a single `graph.replay()` covers multiple iterations)
  - `--output-file`: Save results as a Markdown file (only rank 0 writes the file)
## Output Example
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
```
================================================================================
Results: seq_len=1024, hidden_dim=1024
dtype=torch.bfloat16, residual=yes, quant_mode=none
================================================================================
Operation                                          Time (ms)    Speedup
--------------------------------------------------------------------------------
standard_allreduce_rmsnorm                         0.024        0.98x
standard_allreduce_rmsnorm_native_compiled         0.023        baseline
flashinfer_fused_allreduce_rmsnorm_oneshot         0.011        2.19x
flashinfer_fused_allreduce_rmsnorm_twoshot         0.041        0.57x
```
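The speedup column is computed roughly as in the sketch below (a simplified view of `prepare_results_with_speedups`: the faster of the two standard variants becomes the baseline; timings reused from the table above):
```python
# Simplified speedup computation (illustrative; printed ratios may differ
# slightly from the table, which uses unrounded timings).
results = {
    "standard_allreduce_rmsnorm": 0.024,
    "standard_allreduce_rmsnorm_native_compiled": 0.023,
    "flashinfer_fused_allreduce_rmsnorm_oneshot": 0.011,
    "flashinfer_fused_allreduce_rmsnorm_twoshot": 0.041,
}
# The faster standard implementation becomes the baseline denominator.
baseline = min(
    results["standard_allreduce_rmsnorm"],
    results["standard_allreduce_rmsnorm_native_compiled"],
)
for op, time_ms in results.items():
    print(f"{op}: {baseline / time_ms:.2f}x")
```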
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
## Important Notes and Recommendations
- Distributed: The script reads the `torchrun` environment variables to initialize the distributed environment and binds tensors/communication groups to the device of the current rank.
- World size: Communication-operator benchmarks require `WORLD_SIZE > 1`; otherwise the script exits with an error message.
- FlashInfer:
  - If FlashInfer is not installed or the required interfaces are missing, the script only runs the standard paths and notes this in the logs.
  - The fused operator supports two trigger modes, "oneshot" and "twoshot"; by default both are benchmarked (oneshot can be skipped with `--disable-oneshot`).
- FP8/FP4:
  - FP8 uses sglang's FP8 utilities and dtype; the concrete FP8 format (`e4m3fn`, `e4m3fnuz`, etc.) is selected per platform.
  - FP4 uses sgl-kernel's `scaled_fp4_quant`, which requires platform support.
- CUDA Graph:
- Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter.
# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
    torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
"""
import argparse
import contextlib
import itertools
import logging
import os
import time
from typing import Optional
import torch # type: ignore
import torch.distributed as dist # type: ignore
from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce
from sglang.srt.distributed.parallel_state import (
cleanup_dist_env_and_memory,
graph_capture,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.layernorm import RMSNorm # noqa
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE
from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8
try:
from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM
from sgl_kernel import rmsnorm as SGL_RMS_NORM
from sgl_kernel import scaled_fp4_quant as SGL_SCALED_FP4_QUANT
except Exception: # pragma: no cover - fallback on non-supported platforms
SGL_FUSED_ADD_RMS_NORM = None
SGL_RMS_NORM = None
SGL_SCALED_FP4_QUANT = None
FP8_DTYPE = SGLANG_FP8_DTYPE
logger = logging.getLogger(__name__)
# Try to import FlashInfer
try:
import flashinfer.comm as flashinfer_comm # type: ignore
if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm = None
logger.warning(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except ImportError:
flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations")
# Constants
MiB = 1024 * 1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: 64 * MiB, # 64MB
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None
def setup_flashinfer_workspace(
world_size: int,
rank: int,
hidden_dim: int,
max_token_num: int,
use_fp32_lamport: bool = False,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE_TENSOR
if flashinfer_comm is None:
return None, None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
try:
# Create IPC workspace
ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=world_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=get_tp_group().device_group,
use_fp32_lamport=use_fp32_lamport,
)
)
_FI_WORKSPACE_TENSOR = workspace_tensor
return ipc_handles, workspace_tensor
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
return None, None
def cleanup_flashinfer_workspace(ipc_handles):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or ipc_handles is None:
return
try:
group = get_tp_group().device_group
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
return {
"world_rank": self.rank,
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
}
def flashinfer_fused_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
use_oneshot: bool,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=None,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
use_oneshot: bool = True,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=None,
layout_code=None,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=output_scale,
layout_code=None,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def standard_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
# Fused add + RMS norm (in-place on allreduce_out)
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
rms.forward_native(allreduce_out, residual)
else:
# Just RMS norm
if SGL_RMS_NORM is not None:
_ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
_ = rms.forward_native(allreduce_out)
def standard_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm + static FP8 quantization
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_out, _ = static_quant_fp8(
allreduce_out, scale_factor, repeat_scale=False
)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed, _ = rms.forward_native(allreduce_out, residual)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out, residual
else:
if SGL_RMS_NORM is not None:
normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed = rms.forward_native(allreduce_out)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out
def standard_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_input = allreduce_out
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input, _ = rms.forward_native(allreduce_out, residual)
residual_out = residual
else:
if SGL_RMS_NORM is not None:
quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input = rms.forward_native(allreduce_out)
residual_out = allreduce_out
# Finally FP4 quantization
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, quant_input
def standard_allreduce_rmsnorm_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations using native RMSNorm forward."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
result = rmsnorm_layer.forward_native(allreduce_out, residual)
return result # Returns (norm_out, residual_out)
else:
result = rmsnorm_layer.forward_native(allreduce_out)
return result # Returns norm_out
def standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization using native implementations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
residual_out = allreduce_out
# Apply native FP8 quantization
quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False)
if residual is not None:
return quant_out, residual_out
else:
return quant_out
def standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
quant_input = norm_out
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
quant_input = norm_out
residual_out = allreduce_out
# Apply FP4 quantization (still using fused CUDA op as there's no native FP4)
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, norm_out
# Compiled versions of native functions
@torch.compile
def standard_allreduce_rmsnorm_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm."""
return standard_allreduce_rmsnorm_native(
input_tensor, residual, rmsnorm_layer, norm_out
)
@torch.compile
def standard_allreduce_rmsnorm_fp8_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP8 quantization."""
return standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor,
residual,
rmsnorm_layer,
scale_factor,
norm_out,
quant_out,
)
@torch.compile
def standard_allreduce_rmsnorm_fp4_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP4 quantization."""
return standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor,
residual,
rmsnorm_layer,
input_global_scale,
quant_out,
output_scale,
norm_out,
)
def create_test_tensors(
seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
"""Create test tensors for benchmarking."""
input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype)
residual = (
torch.randn_like(input_tensor)
if use_residual
else torch.zeros_like(input_tensor)
)
rms_gamma = torch.ones(hidden_dim, dtype=dtype)
norm_out = None if use_residual else torch.empty_like(input_tensor)
# Quantization scales
scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8)
fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)
return (
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
)
def benchmark_operation(
operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for _ in range(warmup):
operation_func(*args, **kwargs)
torch.cuda.synchronize()
# Create CUDA graph
graph = torch.cuda.CUDAGraph()
num_op_per_cudagraph = 10
# Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe
with graph_capture() as graph_capture_context:
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(num_op_per_cudagraph):
operation_func(*args, **kwargs)
# Graph warmup
torch.cuda.synchronize()
for _ in range(warmup):
graph.replay()
# Benchmark with CUDA graph
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(trials // num_op_per_cudagraph):
# operation_func(*args, **kwargs)
graph.replay()
torch.cuda.synchronize()
end_time = time.perf_counter()
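    # Each graph.replay() executes num_op_per_cudagraph captured calls and the
    # loop replays trials // num_op_per_cudagraph times, so roughly `trials`
    # operations run in total; dividing by `trials` therefore gives the per-op
    # average (exact when trials is a multiple of num_op_per_cudagraph).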
avg_time_ms = ((end_time - start_time) / trials) * 1000
return avg_time_ms
def run_benchmarks(
seq_len: int,
hidden_dim: int,
dtype: torch.dtype,
use_residual: bool,
allreduce_params: Optional[FlashInferFusedAllReduceParams],
quant_mode: str = "all",
disable_oneshot: bool = False,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual)
rms_eps = 1e-6
results = {}
# Create RMSNorm once for native benchmarks
rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps)
rmsnorm_layer.weight.data = rms_gamma
if quant_mode in ["all", "none"]:
# Standard AllReduce + RMSNorm
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
)
results["standard_allreduce_rmsnorm"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm failed: %s", e)
results["standard_allreduce_rmsnorm"] = float("inf")
# Standard AllReduce + RMSNorm Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf")
if quant_mode in ["all", "fp8_only"]:
# Standard AllReduce + RMSNorm + FP8 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
# quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function
scale_factor=scale_fp8,
norm_out=norm_out,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float(
"inf"
)
if quant_mode in ["all", "fp4_only"]:
# Standard AllReduce + RMSNorm + FP4 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
)
results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
return results
def prepare_results_with_speedups(results_dict):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results = []
# Determine the fastest baseline for each operation type
def get_fastest_baseline(op_name, results_dict):
"""Get the fastest baseline between standard and native_compiled versions."""
if "fp8_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp8_quant",
"standard_allreduce_rmsnorm_fp8_quant_native_compiled",
]
elif "fp4_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp4_quant",
"standard_allreduce_rmsnorm_fp4_quant_native_compiled",
]
else:
candidates = [
"standard_allreduce_rmsnorm",
"standard_allreduce_rmsnorm_native_compiled",
]
# Find the fastest among available candidates
fastest_time = float("inf")
fastest_baseline = None
for candidate in candidates:
if (
candidate in results_dict
and results_dict[candidate] != float("inf")
and results_dict[candidate] < fastest_time
):
fastest_time = results_dict[candidate]
fastest_baseline = candidate
return fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping = {}
for op_name in results_dict:
if (
op_name.startswith("flashinfer_")
or op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
dynamic_baseline_mapping[op_name] = get_fastest_baseline(
op_name, results_dict
)
for op_name, time_ms in results_dict.items():
if time_ms == float("inf"):
speedup_str = "FAILED"
time_str = "FAILED"
else:
time_str = f"{time_ms:.3f}"
# Find the appropriate baseline for this operation
baseline_op = dynamic_baseline_mapping.get(op_name)
if baseline_op and baseline_op in results_dict:
baseline_time = results_dict[baseline_op]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
# For baseline operations, determine if this is the fastest baseline
if op_name.endswith("_native_compiled") or (
op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
fastest_baseline = get_fastest_baseline(op_name, results_dict)
if fastest_baseline == op_name:
speedup_str = "baseline"
else:
if fastest_baseline and fastest_baseline in results_dict:
baseline_time = results_dict[fastest_baseline]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
prepared_results.append(
{
"operation": op_name,
"time_ms": time_ms,
"time_str": time_str,
"speedup_str": speedup_str,
}
)
return prepared_results
def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode):
"""Print benchmark results in a formatted table."""
print(f"\n{'=' * 80}")
print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}")
print(
f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
f"quant_mode={quant_mode}"
)
print(f"{'=' * 80}")
print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
print(f"{'-' * 80}")
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
if result["time_ms"] == float("inf"):
time_display = result["time_str"]
else:
time_display = f"{result['time_ms']:.3f}"
print(
f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
)
def format_results_markdown(
all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
"""Format all benchmark results as markdown."""
markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results
**World Size:** {world_size}
**Hidden Dimension:** {args.hidden_dim}
**Warmup Iterations:** {args.warmup}
**Benchmark Trials:** {args.trials}
**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"}
---
"""
for result in all_results:
seq_len = result["seq_len"]
dtype = result["dtype"]
use_residual = result["use_residual"]
results_dict = result["results"]
residual_str = "with residual" if use_residual else "no residual"
markdown += f"""
## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}
| Operation | Time (ms) | Speedup |
|-----------|-----------|---------|
"""
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
# Format operation name for better readability
formatted_op_name = result["operation"].replace("_", " ").title()
markdown += f"| {formatted_op_name} | {result['time_str']} |"
markdown += f"{result['speedup_str']} |\n"
markdown += "\n"
return markdown
def save_results_to_file(
all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if rank != 0:
return
if not all_results:
logger.warning("No results to save")
return
output_path = args.output_file
try:
markdown_content = format_results_markdown(all_results, world_size, args)
with open(output_path, "w") as f:
f.write(markdown_content)
except Exception as e:
logger.error("Failed to save results to file: %s", e)
def main():
parser = argparse.ArgumentParser(
description="Benchmark fused collective operations"
)
parser.add_argument(
"--seq-lens",
type=int,
nargs="+",
default=[128, 512, 1024, 2048],
help="Sequence lengths to test",
)
parser.add_argument(
"--hidden-dim", type=int, default=8192, help="Hidden dimension size"
)
parser.add_argument(
"--dtypes",
type=str,
nargs="+",
default=["bfloat16"],
choices=["float16", "bfloat16", "float32"],
help="Data types to test",
)
parser.add_argument(
"--no-residual",
action="store_true",
help="Skip residual connection tests",
)
# Quantization mode options (mutually exclusive with --no-quant)
quant_group = parser.add_mutually_exclusive_group()
quant_group.add_argument(
"--no-quant", action="store_true", help="Skip all quantization tests"
)
quant_group.add_argument(
"--quant-fp8", action="store_true", help="Only run FP8 quantization tests"
)
quant_group.add_argument(
"--quant-fp4", action="store_true", help="Only run FP4 quantization tests"
)
quant_group.add_argument(
"--quant-all",
action="store_true",
help="Run all quantization tests (default)",
)
parser.add_argument(
"--disable-oneshot",
action="store_true",
help="Disable oneshot mode for FlashInfer operations",
)
parser.add_argument(
"--warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--trials", type=int, default=20, help="Number of benchmark trials"
)
    parser.add_argument(
        "--output-file",
        type=str,
        help="Output file path for markdown results (if omitted, results are only printed)",
    )
args = parser.parse_args()
# Check if running with torchrun (required for collective operations)
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
raise RuntimeError(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Validate world size (must be > 1 for collective operations)
if world_size <= 1:
raise ValueError(
"World size must be > 1 for collective operations benchmarking. "
f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
)
# Determine quantization mode
if args.no_quant:
quant_mode = "none"
elif args.quant_fp8:
quant_mode = "fp8_only"
elif args.quant_fp4:
quant_mode = "fp4_only"
else: # args.quant_all or default
quant_mode = "all"
if rank == 0:
logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
logger.info("Quantization mode: %s", quant_mode)
if flashinfer_comm is not None:
oneshot_status = "enabled" if not args.disable_oneshot else "disabled"
logger.info(
"FlashInfer available - will benchmark fused operations (oneshot: %s)",
oneshot_status,
)
else:
logger.info(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtypes = [dtype_map[dt] for dt in args.dtypes]
# Test configurations
residual_options = [True] if not args.no_residual else [False]
if not args.no_residual:
residual_options.append(False)
configs = list(itertools.product(args.seq_lens, dtypes, residual_options))
# Setup FlashInfer workspace if available
ipc_handles = None
allreduce_params = None
    if flashinfer_comm is not None:
        # Size the IPC workspace from the per-world-size max message size,
        # the configured hidden dimension, and the world size (falling back to
        # 64 MiB for world sizes not listed in _FI_MAX_SIZES).
        max_num_token = _FI_MAX_SIZES.get(world_size, 64 * MiB) // (
            args.hidden_dim * world_size * 2
        )
ipc_handles, workspace_tensor = setup_flashinfer_workspace(
world_size, rank, args.hidden_dim, max_num_token
)
if workspace_tensor is not None:
allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
max_token_num=max_num_token,
)
# Collect all results for markdown export
all_results = []
try:
# Run benchmarks
for seq_len, dtype, use_residual in configs:
if rank == 0:
logger.info(
"\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s",
seq_len,
args.hidden_dim,
dtype,
use_residual,
)
results = run_benchmarks(
seq_len,
args.hidden_dim,
dtype,
use_residual,
allreduce_params,
quant_mode=quant_mode,
disable_oneshot=args.disable_oneshot,
)
# Store results for markdown export
if rank == 0:
all_results.append(
{
"seq_len": seq_len,
"hidden_dim": args.hidden_dim,
"dtype": str(dtype).replace("torch.", ""),
"use_residual": use_residual,
"quant_mode": quant_mode,
"results": results,
}
)
print_results(
results,
seq_len,
args.hidden_dim,
dtype,
use_residual,
quant_mode,
)
# Save results to markdown file
if args.output_file and rank == 0:
save_results_to_file(all_results, world_size, args, rank)
finally:
# Cleanup
if ipc_handles is not None:
cleanup_flashinfer_workspace(ipc_handles)
with contextlib.suppress(Exception):
dist.barrier()
cleanup_dist_env_and_memory(shutdown_ray=False)
if __name__ == "__main__":
main()
## Tuning Triton MoE Kernels
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
### Tuning Tool
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
Example usage:
```bash
# Tune Mixtral-8x7B with default settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tune
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune Qwen3-235B-A22B-FP8 with TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen3-235B-A22B-FP8 \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-V3 with FP8 and TP=8
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model meituan/DeepSeek-R1-Channel-INT8 \
--tp-size 16 \
--dtype int8_w8a8 \
--tune
```
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file into the subdirectory matching your Triton version under `sglang/srt/layers/fused_moe_triton/configs/` so that `sglang` can pick it up.
### Performance Comparison Tool
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
Example usage:
```bash
# Compare with default settings (Mixtral model)
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--use-fp8-w8a8
# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8
```
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
- `benchmark_torch_compile_fused_moe.py`: A tool for comparing the performance of the fused MoE kernel under `torch.compile` against the original fused MoE kernel.
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`; note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused MoE kernels.
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
import argparse
import torch
import triton
from transformers import AutoConfig
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_triton_api(
x,
w1,
w2,
input_gating,
topk,
):
topk_op = TopK(
top_k=topk,
renormalize=False,
use_grouped_topk=False,
)
topk_op.use_triton_kernels = True
triton_topk_output = topk_op.forward_cuda(
hidden_states=x,
router_logits=input_gating,
)
moe_runner_config = MoeRunnerConfig(
inplace=False,
)
return triton_kernel_moe_forward(
x,
w1,
w2,
triton_topk_output,
moe_runner_config,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
topk_output = select_experts(
hidden_states=x,
router_logits=input_gating,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
return fused_moe_sglang(
x,
w1,
w2,
topk_output,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
line_arg="provider",
line_vals=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
line_names=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(
batch_size,
provider,
model_config,
use_fp8_w8a8=False,
use_cuda_graph: bool = False,
):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_tri = w1.clone()
w2_tri = w2.clone()
w1_tri = w1_tri.transpose(-2, -1).contiguous()
w2_tri = w2_tri.transpose(-2, -1).contiguous()
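    # Pre-transposed, contiguous weight copies for the triton_kernels provider below;
    # the default sglang fused_moe path keeps the original (E, N, K) weight layout.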
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
if provider == "sglang_fused_moe_triton_v340":
api_func = fused_moe_triton_api
api_kwargs = {
"x": x,
"w1": w1_tri,
"w2": w2_tri,
"input_gating": input_gating,
"topk": topk,
}
else:
api_func = fused_moe_sglang_api
api_kwargs = {
"x": x,
"w1": w1,
"w2": w2,
"input_gating": input_gating,
"topk": topk,
"use_fp8_w8a8": use_fp8_w8a8,
"block_shape": block_shape,
}
# Warmup
for _ in range(10):
_ = api_func(**api_kwargs)
torch.cuda.synchronize()
if use_cuda_graph:
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
api_func(**api_kwargs)
torch.cuda.synchronize()
bench_lambda = lambda: graph.replay()
else:
bench_lambda = lambda: api_func(**api_kwargs)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/sglang_fused_moe/",
)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
use_cuda_graph=args.use_cuda_graph,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
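    # Reduces the top-k expert outputs per token and applies the routed scaling:
    #   output[t, d] = routed_scaling_factor * sum_k input[t, k, d]
    # input: (token_num, topk_num, hidden_dim), output: (token_num, hidden_dim).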
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 1
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def compute_sum_scaled_baseline(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
return out
@torch.compile
def compute_sum_scaled_compiled(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x * routed_scaling_factor, dim=1, out=out)
return out
def get_benchmark():
num_tokens_range = [2**i for i in range(0, 13)]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=num_tokens_range,
line_arg="version",
line_vals=["baseline", "compiled", "triton"],
line_names=["Original", "TorchCompile", "TritonKernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="sum_scaled_performance",
args={},
)
)
def benchmark(num_tokens, version):
topk = 9
hidden_size = 4096
dtype = torch.bfloat16
scaling_factor = 0.3
x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
# Warmup
for _ in range(3):
if version == "baseline":
compute_sum_scaled_baseline(x, out, scaling_factor)
elif version == "compiled":
compute_sum_scaled_compiled(x, out, scaling_factor)
else:
moe_sum_reduce(x, out, scaling_factor)
# Benchmark
quantiles = [0.5, 0.2, 0.8]
if version == "baseline":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
quantiles=quantiles,
)
elif version == "compiled":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = do_bench(
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def verify_correctness(num_tokens=1024):
x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
scaling_factor = 0.3
out_baseline = torch.empty_like(x[:, 0])
compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
out_compiled = torch.empty_like(out_baseline)
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
out_triton = torch.empty_like(out_baseline)
moe_sum_reduce(x, out_triton, scaling_factor)
if torch.allclose(
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
print(
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
)
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
if __name__ == "__main__":
print("Running correctness verification...")
verify_correctness()
print("\nRunning performance benchmark...")
benchmark = get_benchmark()
benchmark.run(
print_data=True,
# save_path="./configs/benchmark_ops/sum_scaled/"
)
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
from torch.nn import functional as F
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton,
)
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
        E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@torch.compile(dynamic=False)
def fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
) -> torch.Tensor:
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
topk_weights, topk_ids = fused_topk_native(
hidden_states=x,
gating_output=input_gating,
topk=topk,
renormalize=True,
)
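    # einsum index convention below: t = token, a = top-k slot, i = hidden dim, o = intermediate dim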
w13_weights = w1[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = w2[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
def fused_moe_torch_compile(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_triton(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 5)),
line_arg="provider",
line_vals=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
line_names=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
set_torch_compile_config()
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_scale = w2_scale = a1_scale = a2_scale = None
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_torch_compile
if provider == "fused_moe_torch_compile"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/fused_moe_torch_compile/",
)
args = parser.parse_args()
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
if __name__ == "__main__":
main()
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
vllm_version_num = (
vllm.__version_tuple__[0] * 100
+ vllm.__version_tuple__[1] * 10
+ vllm.__version_tuple__[2]
)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_vllm_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
if block_shape is not None:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
else:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
return fused_moe_sglang(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 513)),
line_arg="provider",
line_vals=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
line_names=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
if block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_vllm_api
if provider == "vllm_fused_moe_triton"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
)
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
get_default_config,
get_moe_configs,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int] = None,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
num_experts, shard_intermediate_size, dtype=torch.float32
)
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(x, input_gating, topk_config)
def prepare(i: int):
input_gating = gating_output[i]
new_topk_output = select_experts(x, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
topk_output.router_logits.copy_(new_topk_output.router_logits)
def run():
moe_runner_config = MoeRunnerConfig(
inplace=True,
)
with override_config(config):
fused_moe(
x,
w1,
w2,
topk_output,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
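    # Each graph.replay() therefore executes 10 fused_moe calls back-to-back, which is
    # why the average latency below divides the measured time by num_iters * 10.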
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
)
if op_config is None:
config = get_default_config(
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
False,
block_shape,
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
num_iters=10,
)
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
search_space = get_configs_compute_bound()
if block_shape is not None:
block_n, block_k = block_shape[0], block_shape[1]
search_space = [
config
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
start = time.perf_counter()
configs = _distribute(
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
)
for batch_size in batch_sizes
],
)
best_configs = {
M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
args = parser.parse_args()
main(args)
import itertools
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
@triton.jit
def _decode_kernel(
Q,
K,
V,
KV,
Out,
S,
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr,
e_original: tl.constexpr,
):
off_bh = tl.program_id(0)
off_h = off_bh % h
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
kv_offset = off_bh * d * e
s = tl.load(S + off_h)
ratio = tl.exp(-s)
d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e)
# Create masks for original dimensions
d_mask = d_idx < d_original
e_mask = e_idx < e_original
# Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
)
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def lightning_attn_decode(q, k, v, kv, s):
"""Triton implementation of Lightning Attention decode operation"""
b, h, n, d = q.shape
e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode"
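    # Per (batch, head) the kernel applies the decode recurrence
    #   kv_new = exp(-s) * kv + k^T v,   o = q @ kv_new,
    # using element-wise ops on dimensions padded to the next power of two.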
# Get padded dimensions (power of 2)
d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e)
# Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel
grid = (b * h, 1)
_decode_kernel[grid](
q_padded,
k_padded,
v_padded,
kv_padded,
o_padded,
s,
b=b,
h=h,
n=n,
d=d_padded,
d_original=d,
e=e_padded,
e_original=e,
)
# Get unpadded outputs
o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e]
return o, kv_out
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
        if not self.training:  # benchmark only exercises the inference path
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
v = v.transpose(1, 2) # [b, n, h, d] -> [b, h, n, e]
self.offset += 1
ratio = torch.exp(-slope_rate) # [h, 1, 1]
# decode mode
kv = past_key_value # [b, h, d, e]
output = []
for i in range(n):
# kv: [b, h, d, e]
# ratio: [h, 1, 1]
# k: [b, h, n, d]
# v: [b, h, n, e]
# k[:, :, i : i + 1]: [b, h, 1, d]
# v[:, :, i : i + 1]: [b, h, 1, e]
# ratio * kv: [b, h, d, e]
# torch.einsum(
# "... n d, ... n e -> ... d e",
# k[:, :, i : i + 1],
# v[:, :, i : i + 1],
# )
# [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
# q[:, :, i : i + 1]: [b, h, 1, d]
# kv.to(q.dtype): [b, h, d, e]
# torch.einsum(
# "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
# )
# [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def get_activation_fn(activation):
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
return lambda x: x
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 64
seq_len = 1
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
d = model_params["head_dim"]
past_kv = torch.randn(
batch_size,
model_params["num_attention_heads"],
d,
d,
device=device,
)
with torch.no_grad():
model_output, _, new_kv = model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
past_kv = past_kv.contiguous()
slope_rate = slope_rate.contiguous()
# Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
triton_output = model_attn.norm(triton_output)
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)
# Test SGL implementation
sgl_output = torch.empty_like(v)
sgl_new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
sgl_output = sgl_output.transpose(1, 2).contiguous()
sgl_output = sgl_output.view(batch_size, seq_len, -1)
sgl_output = model_attn.norm(sgl_output)
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
sgl_output = model_attn.out_proj(sgl_output)
# Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different kv results",
)
# Verify SGL implementation results
torch.testing.assert_close(
model_output,
sgl_output,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
sgl_new_kv,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different kv results",
)
print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
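    # Geometric per-head decay slopes (the ALiBi-style schedule): for a power-of-two
    # head count n, head h gets slope 2^(-8 * (h + 1) / n); used as s in ratio = exp(-s).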
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
def get_benchmark():
batch_size_range = [i for i in range(1, 33)] # max 32
seq_length_range = [1] # decode mode sequence length is fixed to 1
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
"SGL Implementation",
],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
d = params["head_dim"]
past_kv = torch.randn(
batch_size,
params["num_attention_heads"],
d,
d,
device=device,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "Original":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
),
quantiles=quantiles,
)
elif provider == "Triton":
def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_triton,
quantiles=quantiles,
)
else: # SGL
def run_sgl():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(
q, k, v, past_kv, slope_rate, output, new_kv
)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_sgl,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_decode/",
help="Path to save lightning attention decode benchmark results",
)
args = parser.parse_args()
    # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
    params = {
        "hidden_size": 6144,
        "num_attention_heads": 64,
        "head_dim": 96,
        "hidden_act": "silu",
    }
    # Run correctness test first
    test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)
import itertools
import logging
import math
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
logger = logging.getLogger(__name__)
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
def _fwd_kernel(
Q,
K,
V,
Out,
S, # log lambda
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK: tl.constexpr,
BLOCK_MODEL: tl.constexpr,
):
##### get offset
off_bh = tl.program_id(0)
off_h = off_bh % h
off_e = tl.program_id(1)
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# channel offset
e_offset = off_e * BLOCK_MODEL
##### get block ptr
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
S_block_ptr = S + off_h
##### init diag decay(Lambda); q, k decay; kv
s = tl.load(S_block_ptr)
# q, k decay
off_block = tl.arange(
0, BLOCK
) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
# diag decay
index = off_block[:, None] - off_block[None, :]
s_index = s * index
s_index = tl.where(index >= 0, -s_index, float("-inf"))
diag_decay = tl.exp(s_index)
kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
##### compute
for i in range(NUM_BLOCK):
# load
q = tl.load(
Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
k_trans = tl.load(
K_trans_block_ptr + off_block[None, :] * d,
mask=off_block[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
# compute
qk = tl.dot(q, k_trans) * diag_decay
o_intra = tl.dot(qk, v)
o_inter = tl.dot(q, kv) * q_decay
o = o_intra + o_inter
# save and update
tl.store(
O_block_ptr + off_block[:, None] * e,
o.to(O_block_ptr.dtype.element_ty),
mask=off_block[:, None] < n,
)
kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
off_block += BLOCK
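# Reading note for _fwd_kernel (descriptive only): each program walks its sequence in
# chunks of BLOCK tokens. Within a chunk it forms the decayed intra-chunk attention
# (q @ k^T) * diag_decay @ v, adds the inter-chunk term q @ kv weighted by q_decay, and
# then advances the running state kv = exp(-s * BLOCK) * kv + (k decayed)^T @ v, so
# older chunks fade geometrically with the per-head rate s.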
def lightning_attn2(q, k, v, s):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
b, h, n, d = q.shape
e = v.shape[-1]
# Pad d to next power of 2
d_padded = next_power_of_2(d)
if d_padded != d:
q_padded = F.pad(q, (0, d_padded - d))
k_padded = F.pad(k, (0, d_padded - d))
else:
q_padded = q
k_padded = k
# Pad e to next power of 2
e_padded = next_power_of_2(e)
if e_padded != e:
v_padded = F.pad(v, (0, e_padded - e))
else:
v_padded = v
o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)
BLOCK = 64
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
# parallel over channel
BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))
_fwd_kernel[grid](
q_padded,
k_padded,
v_padded,
o_padded,
s,
b,
h,
n,
d_padded,
e_padded,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
BLOCK_MODEL=BLOCK_MODEL,
)
# Remove padding from output
if e_padded != e:
o = o_padded[..., :e]
else:
o = o_padded
return o
def is_support(dim):
return 16 % dim
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
def lightning_attn_func(q, k, v, s):
b, h, n, d = q.shape
e = v.shape[-1]
assert is_support(d) and is_support(e)
# pad v's feature dim to power of 2
e_pad = next_power_of_2(e)
need_pad = e_pad != e
if need_pad:
v = F.pad(v, (0, e_pad - e))
if d > 128:
# split over head
if 64 % d:
m = 64
elif 32 % d:
m = 32
elif 16 % d:
m = 16
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
o = 0
for i in range(n - 1):
start = arr[i]
end = arr[i + 1]
q1 = q[..., start:end]
k1 = k[..., start:end]
o += lightning_attn2(q1, k1, v, s)
else:
o = lightning_attn2(q, k, v, s)
if need_pad:
o = o[:, :, :, :e]
return o
debug = eval(os.environ.get("debug", default="False"))
do_eval = eval(os.environ.get("do_eval", default="False"))
BLOCK = 256
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
if debug:
logger.info(f"activation: {activation}")
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
logger.info(f"activation: does not support {activation}, use Identity!!!")
return lambda x: x
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
if (not self.training) and (not do_eval):
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if past_key_value is None:
self.offset = q.shape[-2]
else:
self.offset += 1
        # to align with metaseq
ratio = torch.exp(-slope_rate)
# only use for the first time
if past_key_value is None:
slope_rate = slope_rate.to(torch.float32)
if attn_mask is not None:
v = v.masked_fill(
(1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
)
NUM_BLOCK = (n + BLOCK - 1) // BLOCK
b, h, n, d = q.shape
e = v.shape[-1]
# other
array = torch.arange(BLOCK).to(q) + 1
q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
index = array[:, None] - array[None, :]
s_index = (
slope_rate
* index[
None,
None,
]
)
s_index = torch.where(index >= 0, -s_index, float("-inf"))
diag_decay = torch.exp(s_index)
kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
for i in range(NUM_BLOCK):
si = i * BLOCK
ei = min(si + BLOCK, n)
m = ei - si
qi = q[:, :, si:ei].contiguous()
ki = k[:, :, si:ei].contiguous()
vi = v[:, :, si:ei].contiguous()
qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
# diag
qk = (
torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
* diag_decay[:, :, :m, :m]
)
qkv_diag = torch.matmul(qk, vi.to(torch.float32))
block_decay = torch.exp(-slope_rate * m)
output[:, :, si:ei] = qkv_none_diag + qkv_diag
kv = block_decay * kv + torch.matmul(
(ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
)
else:
kv = past_key_value
output = []
for i in range(n):
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(
n
) # In the paper, we only train models that have 2^a heads for some a. This function has
else: # some good properties that only occur when the input is a power of 2. To maintain that even
closest_power_of_2 = 2 ** math.floor(
math.log2(n)
) # when the number of heads is not a power of 2, we use this workaround.
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
# h, 1, 1
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
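# Worked example: for n_attention_heads = 8 the recursion reduces to
# get_slopes_power_of_2(8), giving slopes 2^-1, 2^-2, ..., 2^-8, returned with shape
# (8, 1, 1); non-power-of-2 head counts interleave the slopes of the two nearest
# powers of two, per the ALiBi-style workaround noted in the comments above.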
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 2
seq_len = 1024
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
with torch.no_grad():
model_output, _, _ = model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
lib_output = model_attn.out_proj(lib_output)
torch.testing.assert_close(
model_output,
lib_output,
rtol=1e-3,
atol=1e-2,
msg="Lightning attention implementations produce different results",
)
print("✅ Two implementations match")
def get_benchmark():
batch_size_range = [2**i for i in range(0, 7)] # max 64
seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["MiniMax-Text-01", "OpenNLPLab"],
line_names=[
"MiniMax-Text-01 Model Implementation",
"OpenNLPLab Library Implementation",
],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="lightning-attention-prefill-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
quantiles = [0.5, 0.2, 0.8]
if provider == "MiniMax-Text-01":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
),
quantiles=quantiles,
)
else:
def run_lib():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = (
torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
)
return model_attn.out_proj(lib_output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_lib,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_prefill/",
help="Path to save lightning attention prefill benchmark results",
)
args = parser.parse_args()
# Run correctness test first
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "silu",
}
test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)
import argparse
import itertools
import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
)
torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)
return ms, min_ms, max_ms
def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
import argparse
import torch
import triton
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
@torch.compile(backend="inductor")
def torch_int8_quant(x):
int8_max = torch.iinfo(torch.int8).max
abs_max = x.abs().max(dim=-1, keepdim=True).values
scales = abs_max.to(torch.float32) / float(int8_max)
q_x = (x / scales).round().to(torch.int8)
return q_x, scales
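# Minimal dequantization sketch (not used by the benchmark; the helper name is ours,
# purely illustrative): the scheme above stores one scale per row, abs_max / 127, so
# multiplying the int8 values back by that scale recovers the input up to rounding.
def torch_int8_dequant(q_x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # q_x: (M, K) int8 values; scales: (M, 1) float32 per-token scales -> (M, K) float32
    return q_x.to(torch.float32) * scales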
def _test_accuracy_once(M, K, input_dtype, device):
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
out1, scales1 = per_token_quant_int8(x)
out2, scales2 = torch_int8_quant(x)
torch.testing.assert_close(out, out2, atol=1, rtol=0)
torch.testing.assert_close(out, out1, atol=1, rtol=0)
torch.testing.assert_close(scales, scales2)
torch.testing.assert_close(scales1, scales2)
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
def test_accuracy():
Ms = [1, 13, 128, 1024, 2048, 4096]
Ks = [512, 1024, 2048, 8192]
input_dtypes = [torch.float16, torch.bfloat16]
for M in Ms:
for K in Ks:
for input_dtype in input_dtypes:
_test_accuracy_once(M, K, input_dtype, "cuda")
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["vllm op", "triton", "torch.compile"],
line_names=["vllm op", "triton", "torch.compile"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
ylabel="ms",
plot_name="int8 per token quant",
args={},
)
)
def benchmark(batch_size, provider):
M, K = batch_size, 16384
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm op":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_int8_quant(x, symmetric=True),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: per_token_quant_int8(x),
quantiles=quantiles,
)
if provider == "torch.compile":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_int8_quant(x),
quantiles=quantiles,
)
return ms, min_ms, max_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_int8_quant_res",
help="Path to save int8 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any, Dict, List
import torch
import triton
from tqdm import tqdm
mp.set_start_method("spawn", force=True)
from sglang.srt.layers.quantization.fp8_kernel import (
_w8a8_block_fp8_matmul,
_w8a8_block_fp8_matmul_unrolledx4,
)
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
_is_hip = is_hip()
DTYPE_MAP = {
"float32": torch.float32,
"float16": torch.float16,
"half": torch.half,
"bfloat16": torch.bfloat16,
}
def w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
config: Dict[str, Any],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
            if (_is_hip and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
else:
kernel = _w8a8_block_int8_matmul
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
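# Shape reference for w8a8_block_matmul (a restatement of the asserts above, not new
# behavior): with block_size = [block_n, block_k],
#   A:  (..., K) fp8/int8 activation, contiguous   As: (..., cdiv(K, block_k)) scales
#   B:  (N, K)   fp8/int8 weight, contiguous       Bs: (cdiv(N, block_n), cdiv(K, block_k)) scales
# and the returned C has shape (..., N) in `output_dtype`.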
def get_rocm_configs_compute_bound():
configs = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound():
configs = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
def get_weight_shapes(tp_size):
    # NOTE(HandH1998): The weight shapes only work for DeepSeek-V3. Modify them if you tune a different model.
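    # For example, with tp_size=8 the "N can TP" shape (18432 * 2, 7168) becomes
    # (4608, 7168) and the "K can TP" shape (7168, 16384) becomes (7168, 2048).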
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def benchmark_config(
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
    # JIT compilation & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
return avg
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
factor_for_scale = 1e-2
if input_type == "fp8":
fp8_info = torch.finfo(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
else:
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
Bs = (
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
* factor_for_scale
)
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
A,
B,
As,
Bs,
block_size,
config,
out_dtype,
num_iters=10,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
assert best_config is not None
return best_config
def save_configs(
N,
K,
block_n,
block_k,
configs,
save_path,
input_type="fp8",
) -> None:
os.makedirs(save_path, exist_ok=True)
device_name = get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
with open(config_file_path, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_available_gpu_count():
"""Get the number of available GPUs."""
return torch.cuda.device_count()
def tune_on_gpu(args_dict):
"""Run tuning on a specific GPU."""
gpu_id = args_dict["gpu_id"]
batch_sizes = args_dict["batch_sizes"]
weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"]
torch.cuda.set_device(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n
block_k = args.block_k
out_dtype = DTYPE_MAP[args.out_dtype]
save_path = args.save_path
input_type = args.input_type
search_space = get_configs_compute_bound()
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.perf_counter()
results = {}
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
N, K = shape[0], shape[1]
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
benchmark_results = [
tune(
batch_size,
N,
K,
[block_n, block_k],
out_dtype,
search_space,
input_type,
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.perf_counter()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
def distribute_batch_sizes(batch_sizes, num_gpus):
"""Distribute batch sizes across available GPUs."""
batches_per_gpu = []
for i in range(num_gpus):
start_idx = i * len(batch_sizes) // num_gpus
end_idx = (i + 1) * len(batch_sizes) // num_gpus
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
return batches_per_gpu
def main(args):
print(args)
num_gpus = get_available_gpu_count()
if num_gpus == 0:
raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning")
torch.cuda.init()
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
num_gpus = 1 # If only one batch size, use only one GPU
weight_shapes = get_weight_shapes(args.tp_size)
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
process_args = []
for gpu_id in range(num_gpus):
process_args.append(
{
"gpu_id": gpu_id,
"batch_sizes": batches_per_gpu[gpu_id],
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args,
}
)
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
pool.map(tune_on_gpu, process_args)
print("Multi-GPU tuning completed")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument(
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
)
parser.add_argument(
"--out-dtype",
type=str,
choices=["float32", "float16", "bfloat16", "half"],
default="float16",
)
parser.add_argument("--block-n", type=int, default=128)
parser.add_argument("--block-k", type=int, default=128)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument(
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
)
args = parser.parse_args()
main(args)
import itertools
from typing import Optional, Tuple, Union
import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
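# All three providers benchmarked below compute RMSNorm,
#   y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps),
# optionally fused with a residual add (x <- x + residual), in which case the
# pre-normalization sum is returned alongside y, mirroring the reference module below.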
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
naive_norm.weight = nn.Parameter(weight)
naive_norm = naive_norm.to(x.device)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
output = naive_norm(x, residual)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
fused_add_rmsnorm(x, residual, weight, eps)
output = (x, residual)
else:
output = rmsnorm(x, weight, eps)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
output = (x, residual)
else:
out = torch.empty_like(x)
vllm_ops.rms_norm(out, x, weight, eps)
output = out
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_vllm = rmsnorm_vllm(
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
args={},
)
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
quantiles = [0.5, 0.2, 0.8]
if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
)
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
import os
import torch
import triton
import triton.language as tl
@torch.compile(dynamic=True)
def get_last_loc_torch(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
return torch.where(
prefix_lens_tensor > 0,
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
torch.full_like(prefix_lens_tensor, -1),
)
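# Both paths return, for each request, the token-pool location of its last cached
# prefix token, i.e. req_to_token[req_pool_indices[i], prefix_lens[i] - 1], or -1 when
# prefix_lens[i] <= 0; the Triton kernel below performs the same gather in fixed-size
# blocks instead of relying on torch.compile.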
@triton.jit
def get_last_loc_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_mask = prefix_lens > 0
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
BLOCK_SIZE = 256
num_tokens = prefix_lens_tensor.shape[0]
result = torch.empty_like(prefix_lens_tensor)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
get_last_loc_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE,
)
return result
def test_get_last_loc():
max_batch = 4097
max_context_len = 6148
batch_size = 20
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)
# Compare results
torch.testing.assert_close(last_loc_res, last_loc_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
line_arg="provider",
line_vals=["reference", "triton"],
line_names=["PyTorch", "Triton"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="get-last-loc-performance",
args={},
)
)
def benchmark(batch_size, provider):
max_batch = 2048
max_context_len = 16384
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_get_last_loc()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/get_last_loc/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)
import itertools
import os
import torch
import triton
import triton.language as tl
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
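# The optimized variant below changes only the launch shape: instead of one program per
# request looping serially over token blocks, it uses a 2D grid of
# (request, token-block) programs and adds max_contiguous / multiple_of hints so the
# copies into req_to_token proceed in parallel with coalesced accesses.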
@triton.jit
def write_req_to_token_pool_triton_optimize(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_token = tl.program_id(1)
req_pool_index = tl.load(req_pool_indices + pid_batch)
pre_len = tl.load(pre_lens + pid_batch)
seq_len = tl.load(seq_lens + pid_batch)
extend_len = seq_len - pre_len
cumsum_start = 0
for i in range(pid_batch):
cumsum_start += tl.load(extend_lens + i)
token_start = pid_token * BLOCK_SIZE
offset = tl.arange(0, BLOCK_SIZE)
actual_offset = token_start + offset
mask = actual_offset < extend_len
src_ptr = out_cache_loc + cumsum_start + actual_offset
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
value = tl.load(src_ptr, mask=mask)
dst_ptr = (
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ actual_offset
+ pre_len
)
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
tl.store(dst_ptr, value, mask=mask)
def write_req_to_token_pool_reference(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
pre_lens: torch.Tensor,
seq_lens: torch.Tensor,
extend_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
) -> None:
"""Reference implementation using PyTorch"""
for i in range(len(req_pool_indices)):
req_pool_idx = req_pool_indices[i].item()
pre_len = pre_lens[i].item()
seq_len = seq_lens[i].item()
extend_len = extend_lens[i].item()
cumsum_start = sum(extend_lens[:i].tolist())
# Copy values from out_cache_loc to req_to_token
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
cumsum_start : cumsum_start + extend_len
]
def test_write_req_to_token_pool():
max_batch = 4097
max_context_len = 6148
batch_size = 1
extend_len = 14
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
# Create copies for reference implementation
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
def grid(batch_size, extend_len):
num_token_blocks = triton.cdiv(extend_len, 512)
return (batch_size, num_token_blocks)
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
# Test case 2: batch size > 1
batch_size = 3
extend_lens_list = [14, 20, 30]
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
max_extend_len = max(extend_lens_list)
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(batch_sizes, extend_lens))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "extend_len"],
x_vals=configs,
line_arg="provider",
line_vals=["reference", "triton", "triton_optimize"],
line_names=["PyTorch", "Triton", "Triton Optimized"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="write-req-to-token-pool-performance",
args={},
)
)
def benchmark(batch_size, extend_len, provider):
max_batch = 256
max_context_len = 16384
extend_lens_list = [extend_len] * batch_size
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
seq_lens = pre_lens + extend_len
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_reference(
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_triton[(batch_size,)](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
),
quantiles=quantiles,
)
else:
def run_optimized():
block_size = 128 if extend_len <= 1024 else 512
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
write_req_to_token_pool_triton_optimize[grid_config](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(
run_optimized, quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_write_req_to_token_pool()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/write_req_to_token_pool/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)
import itertools
import torch
import torch.nn.functional as F
import triton.testing as tt
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
def extend_attention_fwd_torch(
q: torch.Tensor, # [extend_tokens, H_Q, D]
k: torch.Tensor, # [extend_tokens, H_KV, D]
v: torch.Tensor, # [extend_tokens, H_KV, D]
o: torch.Tensor, # [extend_tokens, H_Q, D]
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
qo_indptr: torch.Tensor, # [B+1]
kv_indptr: torch.Tensor, # [B+1]
kv_indices: torch.Tensor, # [prefix_tokens]
sliding_window_size: int,
):
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = 1.0 / D**0.5
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
if group_size != 1:
k_full_hq = k_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
v_full_hq = v_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
# causal
pos_keys = torch.arange(total_len, device=q.device)
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
# sliding window
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = (
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
) # [extend_len, H_Q, total_len]
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
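# Mask semantics of the reference above: a query written at absolute position
# t = prefix_len + j attends to keys at positions p with
# max(0, t - sliding_window_size) <= p <= t (pure causal when no positive window is
# given), over the concatenation of cached prefix tokens and the extend tokens.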
def _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
):
b_seq_len_prefix = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len_extend = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
)
for i in range(B):
s = kv_indptr[i].item()
e = kv_indptr[i + 1].item()
kv_indices[s:e] = torch.arange(
b_start_loc[i],
b_start_loc[i] + b_seq_len_prefix[i],
dtype=torch.int32,
device=device,
)
total_token_num = int(torch.sum(b_seq_len).item())
extend_token_num = int(torch.sum(b_seq_len_extend).item())
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
o_extend_triton = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device=device
)
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
inputs = dict(
q_extend=q_extend,
k_extend=k_extend,
v_extend=v_extend,
k_buffer=k_buffer,
v_buffer=v_buffer,
o_extend_triton=o_extend_triton,
o_extend_torch=o_extend_torch,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
max_len_extend=max_len_extend,
WINDOW_SIZE=WINDOW_SIZE,
)
meta = dict(
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
)
return inputs, meta
def _run_triton(inputs):
extend_attention_fwd(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_triton"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=inputs["max_len_extend"],
sliding_window_size=inputs["WINDOW_SIZE"],
)
def _run_torch_ref(inputs):
extend_attention_fwd_torch(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_torch"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
inputs["WINDOW_SIZE"],
)
N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
PROVIDERS = ["torch", "triton"]
@tt.perf_report(
tt.Benchmark(
x_names=["N_CTX", "WINDOW_SIZE"],
x_vals=CONFIGS,
line_arg="provider",
line_vals=PROVIDERS,
line_names=PROVIDERS,
ylabel="Runtime (ms)",
plot_name="extend_attention_triton_vs_torch",
args={
"B": 32,
"H_Q": 64,
"H_KV": 8,
"D": 128,
"dtype": "bf16",
"device": "cuda",
"check_correctness": False,
"warmup": 25,
"rep": 100,
},
)
)
def bench(
N_CTX,
provider,
B,
H_Q,
H_KV,
D,
dtype,
device,
WINDOW_SIZE,
check_correctness,
warmup,
rep,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
dt = dtype_map[dtype]
inputs, _ = _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
)
if check_correctness and provider == "triton":
_run_triton(inputs)
_run_torch_ref(inputs)
torch.cuda.synchronize()
if not torch.allclose(
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
):
raise AssertionError("Mismatch between triton and torch reference.")
if provider == "triton":
ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
elif provider == "torch":
ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
else:
raise ValueError(provider)
return ms
if __name__ == "__main__":
bench.run(print_data=True, show_plots=False)
## Download data
```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```
```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```
### Results
The offset values refer to the position-id spacing used when forking the context chunks (see `position_ids_offset` in `bench_sglang.py` below, which uses `i * 1000`).
```
# original
Accuracy: 0.940, latency: 332.83 s
# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s
# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```
import argparse
import json
import re
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
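
# Parallel encoding: the body is split into four chunks, each encoded in a separate fork
# with its own position-id offset, then concatenated back before generating the answer.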
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
s += prefix + "\n"
contexts = [body_0, body_1, body_2, body_3]
position_ids_offset = [i * 1000 for i in range(len(contexts))]
forks = s.fork(len(contexts), position_ids_offset)
forks += lambda i: contexts[i] + "\n"
forks.join(mode="concate_and_append")
s += "\n" + suffix
s += sgl.gen("answer", max_tokens=16)
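
# Build one retrieval query per (src_index, dst_percent) pair, run them as a batch,
# and score exact-match accuracy of the predicted line value.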
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
arguments = []
labels = []
sum_src_indices = []
sum_dst_indices = []
for i in range(len(src_indices)):
for j in range(len(dst_percents)):
src_index = src_indices[i]
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [
q
for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index]
body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append(
{
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix,
}
)
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
# Select backend
backend = select_sglang_backend(args)
tic = time.perf_counter()
states = line_retrieval.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
corrects = []
for i in range(len(arguments)):
output = states[i]["answer"]
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
label = labels[i]
# Try all numbers
        findall = re.findall(r"\d+", output)
if not findall:
response_number = output
else:
for response_number in findall:
if response_number == label:
break
correct = response_number == label
corrects.append(correct)
# Log results
summary = (
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
f"Prompt len: {prompt_len}, "
f"Correct: {correct}, "
f"Label: {label}, Predicted: {response_number}, "
)
print(summary)
accuracy = np.mean(corrects)
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "line_retrieval",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
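
# For each --src-index, spread --num-queries-per-src queries evenly over destination percentiles.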
def main(args):
    with open(args.data_path, "r") as fin:
        line_obj = json.load(fin)
num_hoops = args.num_hoops
for src_index in args.src_index:
src_indices = [src_index]
num_queries = args.num_queries_per_src
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
parser.add_argument("--num-queries-per-src", type=int, default=10)
parser.add_argument("--num-hoops", type=int, default=1)
args = add_common_sglang_args_and_parse(parser)
main(args)