Unverified Commit b1fb7e45 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[benchmark] add flashinfer_allreduce_fusion benchmark (#9937)

parent 1b2ff4fb
# FlashInfer Fused AllReduce + RMSNorm Benchmark
This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It aims to compare the performance differences between FlashInfer fused operators in SGLang (trtllm_allreduce_fusion: AllReduce + Residual Add + RMSNorm + optional quantization) and conventional implementations (standard `tensor_model_parallel_all_reduce` + separate RMSNorm/quantization). Specifically, this script tests the timing performance of two implementation paths: 1) Standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization operations.
This benchmark script helps us tune the ipc workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for applications with FP8/FP4 quantized fused operators.
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
## Feature Overview
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
- Optionally compare FP8/FP4 quantized fused paths with standard paths
- Use CUDA Graph capture and batch replay to reduce measurement noise
- Automatically select the faster "standard baseline" (native/compiled version) as the denominator for speedup calculation
- Optionally export results in Markdown format
## Runtime Environment and Prerequisites
- At least 2 GPUs, and launch multi-process distributed training using `torchrun` (NCCL backend)
- Properly install/compile sglang along with sgl-kernel and custom operators
## Quick Start (Command Examples)
The following examples use world_size=2. You can modify `--nproc_per_node` and parameters according to your machine:
- Regular paths only (no quantization):
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP8 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP4 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- Larger hidden dimensions:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
```
## Parameter Description
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
- `--hidden-dim`: Hidden dimension (default: 8192)
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
- Mutually exclusive quantization options:
- `--no-quant`: No quantization testing
- `--quant-fp8`: Only FP8 quantization testing
- `--quant-fp4`: Only FP4 quantization testing
- `--quant-all`: Test all (default)
- FlashInfer related:
- `--disable-oneshot`: Disable oneshot mode (default enables oneshot and tests twoshot simultaneously)
- Runtime configuration:
- `--warmup`: Warmup count before graph capture and before graph replay (default 5)
- `--trials`: Benchmark iteration count (default 20; internally each `graph.replay()` will batch replay multiple times)
- `--output-file`: Save results as Markdown file (only rank0 takes effect)
## Output Example
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
```
================================================================================
Results: seq_len=1024, hidden_dim=1024
dtype=torch.bfloat16, residual=yes, quant_mode=none
================================================================================
Operation Time (ms) Speedup
--------------------------------------------------------------------------------
standard_allreduce_rmsnorm 0.024 0.98x
standard_allreduce_rmsnorm_native_compiled 0.023 baseline
flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x
flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x
```
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
## Important Notes and Recommendations
- Distributed: The script uses `torchrun` environment variables to initialize distributed training and binds tensors/communication groups to the current rank's corresponding device.
- World size: Requires `WORLD_SIZE > 1` to perform communication operator benchmarks. Otherwise, the script will error and prompt.
- FlashInfer:
- If not installed or interfaces are missing, the script will only run standard paths and provide prompts in the logs.
- The fused operator internally uses "oneshot"/"twoshot" two trigger methods; oneshot is enabled by default and twoshot is tested simultaneously.
- FP8/FP4:
- FP8 uses sglang's FP8 tools and dtype, with underlying platform selection of `e4m3`/`e4m3fnuz` etc.
- FP4 uses sgl-kernel's `scaled_fp4_quant`, requiring corresponding platform support.
- CUDA Graph:
- Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter.
# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100
"""
import argparse
import contextlib
import itertools
import logging
import os
import time
from typing import Optional
import torch # type: ignore
import torch.distributed as dist # type: ignore
from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce
from sglang.srt.distributed.parallel_state import (
cleanup_dist_env_and_memory,
graph_capture,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.layernorm import RMSNorm # noqa
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE
from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8
try:
from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM
from sgl_kernel import rmsnorm as SGL_RMS_NORM
from sgl_kernel import scaled_fp4_quant as SGL_SCALED_FP4_QUANT
except Exception: # pragma: no cover - fallback on non-supported platforms
SGL_FUSED_ADD_RMS_NORM = None
SGL_RMS_NORM = None
SGL_SCALED_FP4_QUANT = None
FP8_DTYPE = SGLANG_FP8_DTYPE
logger = logging.getLogger(__name__)
# Try to import FlashInfer
try:
import flashinfer.comm as flashinfer_comm # type: ignore
if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm = None
logger.warning(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except ImportError:
flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations")
# Constants
MiB = 1024 * 1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: 64 * MiB, # 64MB
8: 64 * MiB, # 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None
def setup_flashinfer_workspace(
world_size: int,
rank: int,
hidden_dim: int,
max_token_num: int,
use_fp32_lamport: bool = False,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global _FI_WORKSPACE_TENSOR
if flashinfer_comm is None:
return None, None
if world_size not in _FI_MAX_SIZES:
logger.warning("FlashInfer not supported for world size %s", world_size)
return None, None
try:
# Create IPC workspace
ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=world_size,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
group=get_tp_group().device_group,
use_fp32_lamport=use_fp32_lamport,
)
)
_FI_WORKSPACE_TENSOR = workspace_tensor
return ipc_handles, workspace_tensor
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
return None, None
def cleanup_flashinfer_workspace(ipc_handles):
"""Cleanup FlashInfer workspace."""
if flashinfer_comm is None or ipc_handles is None:
return
try:
group = get_tp_group().device_group
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)
class FlashInferFusedAllReduceParams:
"""Parameters for FlashInfer fused allreduce operations."""
def __init__(
self,
rank: int,
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
self.rank = rank
self.world_size = world_size
self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num
def get_trtllm_fused_allreduce_kwargs(self):
return {
"world_rank": self.rank,
"world_size": self.world_size,
"launch_with_pdl": self.launch_with_pdl,
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
}
def flashinfer_fused_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
allreduce_params: "FlashInferFusedAllReduceParams",
use_oneshot: bool,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=None,
scale_factor=None,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
use_oneshot: bool = True,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=None,
layout_code=None,
scale_factor=scale_factor,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
allreduce_params: FlashInferFusedAllReduceParams,
quant_out: torch.Tensor,
use_oneshot: bool,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")
if norm_out is None:
norm_out = input_tensor
residual_out = residual
else:
residual_out = input_tensor
flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor,
token_num=input_tensor.shape[0],
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
hidden_dim=input_tensor.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
allreduce_out=None,
quant_out=quant_out,
scale_out=output_scale,
layout_code=None,
scale_factor=input_global_scale,
use_oneshot=use_oneshot,
**allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
def standard_allreduce_rmsnorm(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
# Fused add + RMS norm (in-place on allreduce_out)
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
rms.forward_native(allreduce_out, residual)
else:
# Just RMS norm
if SGL_RMS_NORM is not None:
_ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
_ = rms.forward_native(allreduce_out)
def standard_allreduce_rmsnorm_fp8_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm + static FP8 quantization
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_out, _ = static_quant_fp8(
allreduce_out, scale_factor, repeat_scale=False
)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed, _ = rms.forward_native(allreduce_out, residual)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out, residual
else:
if SGL_RMS_NORM is not None:
normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
normed = rms.forward_native(allreduce_out)
quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False)
return quant_out
def standard_allreduce_rmsnorm_fp4_quant(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rms_gamma: torch.Tensor,
rms_eps: float,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Then RMS norm
if residual is not None:
if SGL_FUSED_ADD_RMS_NORM is not None:
SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps)
quant_input = allreduce_out
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input, _ = rms.forward_native(allreduce_out, residual)
residual_out = residual
else:
if SGL_RMS_NORM is not None:
quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps)
else:
rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps)
rms.weight.data = rms_gamma
quant_input = rms.forward_native(allreduce_out)
residual_out = allreduce_out
# Finally FP4 quantization
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, quant_input
def standard_allreduce_rmsnorm_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm operations using native RMSNorm forward."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
result = rmsnorm_layer.forward_native(allreduce_out, residual)
return result # Returns (norm_out, residual_out)
else:
result = rmsnorm_layer.forward_native(allreduce_out)
return result # Returns norm_out
def standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP8 quantization using native implementations."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
residual_out = allreduce_out
# Apply native FP8 quantization
quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False)
if residual is not None:
return quant_out, residual_out
else:
return quant_out
def standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm."""
# All-reduce first
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
# Apply native RMSNorm
if residual is not None:
norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual)
quant_input = norm_out
else:
norm_out = rmsnorm_layer.forward_native(allreduce_out)
quant_input = norm_out
residual_out = allreduce_out
# Apply FP4 quantization (still using fused CUDA op as there's no native FP4)
if SGL_SCALED_FP4_QUANT is None:
raise RuntimeError("scaled_fp4_quant is not available on this platform")
quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale)
if residual is not None:
return quant_res, residual_out, output_scale_res
else:
return quant_res, norm_out
# Compiled versions of native functions
@torch.compile
def standard_allreduce_rmsnorm_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm."""
return standard_allreduce_rmsnorm_native(
input_tensor, residual, rmsnorm_layer, norm_out
)
@torch.compile
def standard_allreduce_rmsnorm_fp8_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
scale_factor: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP8 quantization."""
return standard_allreduce_rmsnorm_fp8_quant_native(
input_tensor,
residual,
rmsnorm_layer,
scale_factor,
norm_out,
quant_out,
)
@torch.compile
def standard_allreduce_rmsnorm_fp4_quant_native_compiled(
input_tensor: torch.Tensor,
residual: Optional[torch.Tensor],
rmsnorm_layer: RMSNorm,
input_global_scale: torch.Tensor,
quant_out: torch.Tensor,
output_scale: torch.Tensor,
norm_out: Optional[torch.Tensor] = None,
):
"""Compiled version of standard allreduce + rmsnorm + FP4 quantization."""
return standard_allreduce_rmsnorm_fp4_quant_native(
input_tensor,
residual,
rmsnorm_layer,
input_global_scale,
quant_out,
output_scale,
norm_out,
)
def create_test_tensors(
seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
"""Create test tensors for benchmarking."""
input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype)
residual = (
torch.randn_like(input_tensor)
if use_residual
else torch.zeros_like(input_tensor)
)
rms_gamma = torch.ones(hidden_dim, dtype=dtype)
norm_out = None if use_residual else torch.empty_like(input_tensor)
# Quantization scales
scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8)
fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)
return (
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
)
def benchmark_operation(
operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for _ in range(warmup):
operation_func(*args, **kwargs)
torch.cuda.synchronize()
# Create CUDA graph
graph = torch.cuda.CUDAGraph()
num_op_per_cudagraph = 10
# Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe
with graph_capture() as graph_capture_context:
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(num_op_per_cudagraph):
operation_func(*args, **kwargs)
# Graph warmup
torch.cuda.synchronize()
for _ in range(warmup):
graph.replay()
# Benchmark with CUDA graph
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(trials // num_op_per_cudagraph):
# operation_func(*args, **kwargs)
graph.replay()
torch.cuda.synchronize()
end_time = time.perf_counter()
avg_time_ms = ((end_time - start_time) / trials) * 1000
return avg_time_ms
def run_benchmarks(
seq_len: int,
hidden_dim: int,
dtype: torch.dtype,
use_residual: bool,
allreduce_params: Optional[FlashInferFusedAllReduceParams],
quant_mode: str = "all",
disable_oneshot: bool = False,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor,
norm_out,
residual,
rms_gamma,
scale_fp8,
quant_out_fp8,
scale_fp4,
fp4_quant_out,
fp4_output_scale,
) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual)
rms_eps = 1e-6
results = {}
# Create RMSNorm once for native benchmarks
rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps)
rmsnorm_layer.weight.data = rms_gamma
if quant_mode in ["all", "none"]:
# Standard AllReduce + RMSNorm
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
)
results["standard_allreduce_rmsnorm"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm failed: %s", e)
results["standard_allreduce_rmsnorm"] = float("inf")
# Standard AllReduce + RMSNorm Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms
except Exception as e:
logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e)
results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf")
# FlashInfer Fused AllReduce + RMSNorm Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e
)
results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf")
if quant_mode in ["all", "fp8_only"]:
# Standard AllReduce + RMSNorm + FP8 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp8_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
# quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function
scale_factor=scale_fp8,
norm_out=norm_out,
quant_out=quant_out_fp8,
)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_fp8,
quant_out=quant_out_fp8,
allreduce_params=allreduce_params,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float(
"inf"
)
if quant_mode in ["all", "fp4_only"]:
# Standard AllReduce + RMSNorm + FP4 Quant
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant,
input_tensor,
norm_out=norm_out,
residual=residual,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
)
results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf")
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
try:
time_ms = benchmark_operation(
standard_allreduce_rmsnorm_fp4_quant_native_compiled,
input_tensor,
residual=residual,
rmsnorm_layer=rmsnorm_layer,
input_global_scale=scale_fp4,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
norm_out=norm_out,
)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e)
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if flashinfer_comm is not None and allreduce_params is not None:
try:
if not disable_oneshot:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=True,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if flashinfer_comm is not None and allreduce_params is not None:
try:
time_ms = benchmark_operation(
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
input_tensor,
residual=residual,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
input_global_scale=scale_fp4,
allreduce_params=allreduce_params,
quant_out=fp4_quant_out,
output_scale=fp4_output_scale,
use_oneshot=False,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
time_ms
)
except Exception as e:
logger.error(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
e,
)
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
"inf"
)
return results
def prepare_results_with_speedups(results_dict):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results = []
# Determine the fastest baseline for each operation type
def get_fastest_baseline(op_name, results_dict):
"""Get the fastest baseline between standard and native_compiled versions."""
if "fp8_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp8_quant",
"standard_allreduce_rmsnorm_fp8_quant_native_compiled",
]
elif "fp4_quant" in op_name:
candidates = [
"standard_allreduce_rmsnorm_fp4_quant",
"standard_allreduce_rmsnorm_fp4_quant_native_compiled",
]
else:
candidates = [
"standard_allreduce_rmsnorm",
"standard_allreduce_rmsnorm_native_compiled",
]
# Find the fastest among available candidates
fastest_time = float("inf")
fastest_baseline = None
for candidate in candidates:
if (
candidate in results_dict
and results_dict[candidate] != float("inf")
and results_dict[candidate] < fastest_time
):
fastest_time = results_dict[candidate]
fastest_baseline = candidate
return fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping = {}
for op_name in results_dict:
if (
op_name.startswith("flashinfer_")
or op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
dynamic_baseline_mapping[op_name] = get_fastest_baseline(
op_name, results_dict
)
for op_name, time_ms in results_dict.items():
if time_ms == float("inf"):
speedup_str = "FAILED"
time_str = "FAILED"
else:
time_str = f"{time_ms:.3f}"
# Find the appropriate baseline for this operation
baseline_op = dynamic_baseline_mapping.get(op_name)
if baseline_op and baseline_op in results_dict:
baseline_time = results_dict[baseline_op]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
# For baseline operations, determine if this is the fastest baseline
if op_name.endswith("_native_compiled") or (
op_name.startswith("standard_")
and not op_name.endswith("_native_compiled")
):
fastest_baseline = get_fastest_baseline(op_name, results_dict)
if fastest_baseline == op_name:
speedup_str = "baseline"
else:
if fastest_baseline and fastest_baseline in results_dict:
baseline_time = results_dict[fastest_baseline]
if baseline_time != float("inf") and baseline_time > 0:
speedup = baseline_time / time_ms
speedup_str = f"{speedup:.2f}x"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
else:
speedup_str = "N/A"
prepared_results.append(
{
"operation": op_name,
"time_ms": time_ms,
"time_str": time_str,
"speedup_str": speedup_str,
}
)
return prepared_results
def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode):
"""Print benchmark results in a formatted table."""
print(f"\n{'=' * 80}")
print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}")
print(
f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
f"quant_mode={quant_mode}"
)
print(f"{'=' * 80}")
print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
print(f"{'-' * 80}")
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
if result["time_ms"] == float("inf"):
time_display = result["time_str"]
else:
time_display = f"{result['time_ms']:.3f}"
print(
f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
)
def format_results_markdown(
all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
"""Format all benchmark results as markdown."""
markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results
**World Size:** {world_size}
**Hidden Dimension:** {args.hidden_dim}
**Warmup Iterations:** {args.warmup}
**Benchmark Trials:** {args.trials}
**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"}
---
"""
for result in all_results:
seq_len = result["seq_len"]
dtype = result["dtype"]
use_residual = result["use_residual"]
results_dict = result["results"]
residual_str = "with residual" if use_residual else "no residual"
markdown += f"""
## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str}
| Operation | Time (ms) | Speedup |
|-----------|-----------|---------|
"""
# Prepare results with speedup calculations
prepared_results = prepare_results_with_speedups(results_dict)
for result in prepared_results:
# Format operation name for better readability
formatted_op_name = result["operation"].replace("_", " ").title()
markdown += f"| {formatted_op_name} | {result['time_str']} |"
markdown += f"{result['speedup_str']} |\n"
markdown += "\n"
return markdown
def save_results_to_file(
all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if rank != 0:
return
if not all_results:
logger.warning("No results to save")
return
output_path = args.output_file
try:
markdown_content = format_results_markdown(all_results, world_size, args)
with open(output_path, "w") as f:
f.write(markdown_content)
except Exception as e:
logger.error("Failed to save results to file: %s", e)
def main():
parser = argparse.ArgumentParser(
description="Benchmark fused collective operations"
)
parser.add_argument(
"--seq-lens",
type=int,
nargs="+",
default=[128, 512, 1024, 2048],
help="Sequence lengths to test",
)
parser.add_argument(
"--hidden-dim", type=int, default=8192, help="Hidden dimension size"
)
parser.add_argument(
"--dtypes",
type=str,
nargs="+",
default=["bfloat16"],
choices=["float16", "bfloat16", "float32"],
help="Data types to test",
)
parser.add_argument(
"--no-residual",
action="store_true",
help="Skip residual connection tests",
)
# Quantization mode options (mutually exclusive with --no-quant)
quant_group = parser.add_mutually_exclusive_group()
quant_group.add_argument(
"--no-quant", action="store_true", help="Skip all quantization tests"
)
quant_group.add_argument(
"--quant-fp8", action="store_true", help="Only run FP8 quantization tests"
)
quant_group.add_argument(
"--quant-fp4", action="store_true", help="Only run FP4 quantization tests"
)
quant_group.add_argument(
"--quant-all",
action="store_true",
help="Run all quantization tests (default)",
)
parser.add_argument(
"--disable-oneshot",
action="store_true",
help="Disable oneshot mode for FlashInfer operations",
)
parser.add_argument(
"--warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--trials", type=int, default=20, help="Number of benchmark trials"
)
parser.add_argument(
"--output-file",
type=str,
help="""Output file path for markdown results
(default: benchmark_results_<timestamp>.md)
""",
)
args = parser.parse_args()
# Check if running with torchrun (required for collective operations)
if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
raise RuntimeError(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Validate world size (must be > 1 for collective operations)
if world_size <= 1:
raise ValueError(
"World size must be > 1 for collective operations benchmarking. "
f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
)
# Determine quantization mode
if args.no_quant:
quant_mode = "none"
elif args.quant_fp8:
quant_mode = "fp8_only"
elif args.quant_fp4:
quant_mode = "fp4_only"
else: # args.quant_all or default
quant_mode = "all"
if rank == 0:
logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
logger.info("Quantization mode: %s", quant_mode)
if flashinfer_comm is not None:
oneshot_status = "enabled" if not args.disable_oneshot else "disabled"
logger.info(
"FlashInfer available - will benchmark fused operations (oneshot: %s)",
oneshot_status,
)
else:
logger.info(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtypes = [dtype_map[dt] for dt in args.dtypes]
# Test configurations
residual_options = [True] if not args.no_residual else [False]
if not args.no_residual:
residual_options.append(False)
configs = list(itertools.product(args.seq_lens, dtypes, residual_options))
# Setup FlashInfer workspace if available
ipc_handles = None
allreduce_params = None
if flashinfer_comm is not None:
# Use the largest hidden dimension for workspace setup
max_num_token = _FI_MAX_SIZES.get(world_size) // (
args.hidden_dim * world_size * 2
)
ipc_handles, workspace_tensor = setup_flashinfer_workspace(
world_size, rank, args.hidden_dim, max_num_token
)
if workspace_tensor is not None:
allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
max_token_num=max_num_token,
)
# Collect all results for markdown export
all_results = []
try:
# Run benchmarks
for seq_len, dtype, use_residual in configs:
if rank == 0:
logger.info(
"\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s",
seq_len,
args.hidden_dim,
dtype,
use_residual,
)
results = run_benchmarks(
seq_len,
args.hidden_dim,
dtype,
use_residual,
allreduce_params,
quant_mode=quant_mode,
disable_oneshot=args.disable_oneshot,
)
# Store results for markdown export
if rank == 0:
all_results.append(
{
"seq_len": seq_len,
"hidden_dim": args.hidden_dim,
"dtype": str(dtype).replace("torch.", ""),
"use_residual": use_residual,
"quant_mode": quant_mode,
"results": results,
}
)
print_results(
results,
seq_len,
args.hidden_dim,
dtype,
use_residual,
quant_mode,
)
# Save results to markdown file
if args.output_file and rank == 0:
save_results_to_file(all_results, world_size, args, rank)
finally:
# Cleanup
if ipc_handles is not None:
cleanup_flashinfer_workspace(ipc_handles)
with contextlib.suppress(Exception):
dist.barrier()
cleanup_dist_env_and_memory(shutdown_ray=False)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment