# Modified from https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py """ Benchmark for FlashInfer fused collective operations vs standard operations. This benchmark compares: 1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) 2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations Usage with torchrun: torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 1024 --seq-len 512 1024 2048 4096 --trials 100 torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --no-quant --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp8 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 torchrun --nproc_per_node=2 benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py --quant-fp4 --hidden-dim 4096 --seq-len 512 1024 2048 4096 --trials 100 """ import argparse import contextlib import itertools import logging import os import time from typing import Optional import torch # type: ignore import torch.distributed as dist # type: ignore from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce from sglang.srt.distributed.parallel_state import ( cleanup_dist_env_and_memory, graph_capture, init_distributed_environment, initialize_model_parallel, ) from sglang.srt.layers.layernorm import RMSNorm # noqa from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype as SGLANG_FP8_DTYPE from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8 try: from sgl_kernel import fused_add_rmsnorm as SGL_FUSED_ADD_RMS_NORM from sgl_kernel import rmsnorm as SGL_RMS_NORM from sgl_kernel import scaled_fp4_quant as SGL_SCALED_FP4_QUANT except Exception: # pragma: no cover - fallback on non-supported platforms SGL_FUSED_ADD_RMS_NORM = None SGL_RMS_NORM = None SGL_SCALED_FP4_QUANT = None FP8_DTYPE = SGLANG_FP8_DTYPE logger = logging.getLogger(__name__) # Try to import FlashInfer try: import flashinfer.comm as flashinfer_comm # type: ignore if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): flashinfer_comm = None logger.warning( "FlashInfer comm module found but missing trtllm_allreduce_fusion" ) except ImportError: flashinfer_comm = None logger.warning("FlashInfer not found, only benchmarking standard operations") # Constants MiB = 1024 * 1024 # FlashInfer max sizes per world size # Enable 64MB for 2, 4, 8 world sizes to verify large input sizes # use --disable-oneshot to disable oneshot mode for very large input sizes _FI_MAX_SIZES = { 2: 64 * MiB, # 64MB 4: 64 * MiB, # 64MB 8: 64 * MiB, # 64MB } # Global workspace tensor for FlashInfer _FI_WORKSPACE_TENSOR = None def setup_flashinfer_workspace( world_size: int, rank: int, hidden_dim: int, max_token_num: int, use_fp32_lamport: bool = False, ): """Setup FlashInfer workspace for fused allreduce operations.""" global _FI_WORKSPACE_TENSOR if flashinfer_comm is None: return None, None if world_size not in _FI_MAX_SIZES: logger.warning("FlashInfer not supported for world size %s", world_size) return None, None try: # Create IPC workspace ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=world_size, max_token_num=max_token_num, hidden_dim=hidden_dim, group=get_tp_group().device_group, use_fp32_lamport=use_fp32_lamport, ) ) _FI_WORKSPACE_TENSOR = workspace_tensor return ipc_handles, workspace_tensor except Exception as e: logger.error("Failed to setup FlashInfer workspace: %s", e) return None, None def cleanup_flashinfer_workspace(ipc_handles): """Cleanup FlashInfer workspace.""" if flashinfer_comm is None or ipc_handles is None: return try: group = get_tp_group().device_group flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) except Exception as e: logger.error("Failed to cleanup FlashInfer workspace: %s", e) class FlashInferFusedAllReduceParams: """Parameters for FlashInfer fused allreduce operations.""" def __init__( self, rank: int, world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, ): self.rank = rank self.world_size = world_size self.use_fp32_lamport = use_fp32_lamport self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True self.max_token_num = max_token_num def get_trtllm_fused_allreduce_kwargs(self): return { "world_rank": self.rank, "world_size": self.world_size, "launch_with_pdl": self.launch_with_pdl, "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, } def flashinfer_fused_allreduce_rmsnorm( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, allreduce_params: "FlashInferFusedAllReduceParams", use_oneshot: bool, norm_out: Optional[torch.Tensor] = None, ): """FlashInfer fused allreduce + rmsnorm operation.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: norm_out = input_tensor residual_out = residual else: residual_out = input_tensor flashinfer_comm.trtllm_allreduce_fusion( allreduce_in=input_tensor, token_num=input_tensor.shape[0], residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, hidden_dim=input_tensor.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, allreduce_out=None, quant_out=None, scale_out=None, layout_code=None, scale_factor=None, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) def flashinfer_fused_allreduce_rmsnorm_fp8_quant( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, use_oneshot: bool = True, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, ): """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: norm_out = input_tensor residual_out = residual else: residual_out = input_tensor flashinfer_comm.trtllm_allreduce_fusion( allreduce_in=input_tensor, token_num=input_tensor.shape[0], residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, hidden_dim=input_tensor.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, allreduce_out=None, quant_out=quant_out, scale_out=None, layout_code=None, scale_factor=scale_factor, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) def flashinfer_fused_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, quant_out: torch.Tensor, use_oneshot: bool, output_scale: torch.Tensor, norm_out: Optional[torch.Tensor] = None, ): """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: norm_out = input_tensor residual_out = residual else: residual_out = input_tensor flashinfer_comm.trtllm_allreduce_fusion( allreduce_in=input_tensor, token_num=input_tensor.shape[0], residual_in=residual, residual_out=residual_out, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, hidden_dim=input_tensor.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, allreduce_out=None, quant_out=quant_out, scale_out=output_scale, layout_code=None, scale_factor=input_global_scale, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) def standard_allreduce_rmsnorm( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, norm_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm operations.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Then RMS norm if residual is not None: # Fused add + RMS norm (in-place on allreduce_out) if SGL_FUSED_ADD_RMS_NORM is not None: SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma rms.forward_native(allreduce_out, residual) else: # Just RMS norm if SGL_RMS_NORM is not None: _ = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma _ = rms.forward_native(allreduce_out) def standard_allreduce_rmsnorm_fp8_quant( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, scale_factor: torch.Tensor, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm + FP8 quantization.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Then RMS norm + static FP8 quantization if residual is not None: if SGL_FUSED_ADD_RMS_NORM is not None: SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) quant_out, _ = static_quant_fp8( allreduce_out, scale_factor, repeat_scale=False ) else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma normed, _ = rms.forward_native(allreduce_out, residual) quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False) return quant_out, residual else: if SGL_RMS_NORM is not None: normed = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma normed = rms.forward_native(allreduce_out) quant_out, _ = static_quant_fp8(normed, scale_factor, repeat_scale=False) return quant_out def standard_allreduce_rmsnorm_fp4_quant( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rms_gamma: torch.Tensor, rms_eps: float, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, norm_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm + FP4 quantization.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Then RMS norm if residual is not None: if SGL_FUSED_ADD_RMS_NORM is not None: SGL_FUSED_ADD_RMS_NORM(allreduce_out, residual, rms_gamma, rms_eps) quant_input = allreduce_out else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma quant_input, _ = rms.forward_native(allreduce_out, residual) residual_out = residual else: if SGL_RMS_NORM is not None: quant_input = SGL_RMS_NORM(allreduce_out, rms_gamma, rms_eps) else: rms = RMSNorm(allreduce_out.shape[-1], eps=rms_eps) rms.weight.data = rms_gamma quant_input = rms.forward_native(allreduce_out) residual_out = allreduce_out # Finally FP4 quantization if SGL_SCALED_FP4_QUANT is None: raise RuntimeError("scaled_fp4_quant is not available on this platform") quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale) if residual is not None: return quant_res, residual_out, output_scale_res else: return quant_res, quant_input def standard_allreduce_rmsnorm_native( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, norm_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Apply native RMSNorm if residual is not None: result = rmsnorm_layer.forward_native(allreduce_out, residual) return result # Returns (norm_out, residual_out) else: result = rmsnorm_layer.forward_native(allreduce_out) return result # Returns norm_out def standard_allreduce_rmsnorm_fp8_quant_native( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, scale_factor: torch.Tensor, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Apply native RMSNorm if residual is not None: norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) else: norm_out = rmsnorm_layer.forward_native(allreduce_out) residual_out = allreduce_out # Apply native FP8 quantization quant_out, _ = static_quant_fp8(norm_out, scale_factor, repeat_scale=False) if residual is not None: return quant_out, residual_out else: return quant_out def standard_allreduce_rmsnorm_fp4_quant_native( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, norm_out: Optional[torch.Tensor] = None, ): """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" # All-reduce first allreduce_out = tensor_model_parallel_all_reduce(input_tensor) # Apply native RMSNorm if residual is not None: norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) quant_input = norm_out else: norm_out = rmsnorm_layer.forward_native(allreduce_out) quant_input = norm_out residual_out = allreduce_out # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) if SGL_SCALED_FP4_QUANT is None: raise RuntimeError("scaled_fp4_quant is not available on this platform") quant_res, output_scale_res = SGL_SCALED_FP4_QUANT(quant_input, input_global_scale) if residual is not None: return quant_res, residual_out, output_scale_res else: return quant_res, norm_out # Compiled versions of native functions @torch.compile def standard_allreduce_rmsnorm_native_compiled( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, norm_out: Optional[torch.Tensor] = None, ): """Compiled version of standard allreduce + rmsnorm.""" return standard_allreduce_rmsnorm_native( input_tensor, residual, rmsnorm_layer, norm_out ) @torch.compile def standard_allreduce_rmsnorm_fp8_quant_native_compiled( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, scale_factor: torch.Tensor, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, ): """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" return standard_allreduce_rmsnorm_fp8_quant_native( input_tensor, residual, rmsnorm_layer, scale_factor, norm_out, quant_out, ) @torch.compile def standard_allreduce_rmsnorm_fp4_quant_native_compiled( input_tensor: torch.Tensor, residual: Optional[torch.Tensor], rmsnorm_layer: RMSNorm, input_global_scale: torch.Tensor, quant_out: torch.Tensor, output_scale: torch.Tensor, norm_out: Optional[torch.Tensor] = None, ): """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" return standard_allreduce_rmsnorm_fp4_quant_native( input_tensor, residual, rmsnorm_layer, input_global_scale, quant_out, output_scale, norm_out, ) def create_test_tensors( seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True ): """Create test tensors for benchmarking.""" input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) residual = ( torch.randn_like(input_tensor) if use_residual else torch.zeros_like(input_tensor) ) rms_gamma = torch.ones(hidden_dim, dtype=dtype) norm_out = None if use_residual else torch.empty_like(input_tensor) # Quantization scales scale_fp8 = torch.tensor(1.0, dtype=torch.float32) scale_fp4 = torch.tensor(1.0, dtype=torch.float32) quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) return ( input_tensor, norm_out, residual, rms_gamma, scale_fp8, quant_out_fp8, scale_fp4, fp4_quant_out, fp4_output_scale, ) def benchmark_operation( operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs ): """Benchmark a single operation using CUDA graphs.""" # Warmup before graph capture for _ in range(warmup): operation_func(*args, **kwargs) torch.cuda.synchronize() # Create CUDA graph graph = torch.cuda.CUDAGraph() num_op_per_cudagraph = 10 # Use sglang's graph_capture to make tensor_model_parallel_all_reduce graph-safe with graph_capture() as graph_capture_context: with torch.cuda.graph(graph, stream=graph_capture_context.stream): for _ in range(num_op_per_cudagraph): operation_func(*args, **kwargs) # Graph warmup torch.cuda.synchronize() for _ in range(warmup): graph.replay() # Benchmark with CUDA graph torch.cuda.synchronize() start_time = time.perf_counter() for _ in range(trials // num_op_per_cudagraph): # operation_func(*args, **kwargs) graph.replay() torch.cuda.synchronize() end_time = time.perf_counter() avg_time_ms = ((end_time - start_time) / trials) * 1000 return avg_time_ms def run_benchmarks( seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool, allreduce_params: Optional[FlashInferFusedAllReduceParams], quant_mode: str = "all", disable_oneshot: bool = False, ): """Run all benchmarks for given configuration. Args: quant_mode: "none", "fp8_only", "fp4_only", or "all" """ ( input_tensor, norm_out, residual, rms_gamma, scale_fp8, quant_out_fp8, scale_fp4, fp4_quant_out, fp4_output_scale, ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) rms_eps = 1e-6 results = {} # Create RMSNorm once for native benchmarks rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps) rmsnorm_layer.weight.data = rms_gamma if quant_mode in ["all", "none"]: # Standard AllReduce + RMSNorm try: time_ms = benchmark_operation( standard_allreduce_rmsnorm, input_tensor, norm_out=norm_out, residual=residual, rms_gamma=rms_gamma, rms_eps=rms_eps, ) results["standard_allreduce_rmsnorm"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm failed: %s", e) results["standard_allreduce_rmsnorm"] = float("inf") # Standard AllReduce + RMSNorm Native Compiled try: time_ms = benchmark_operation( standard_allreduce_rmsnorm_native_compiled, input_tensor, residual=residual, rmsnorm_layer=rmsnorm_layer, norm_out=norm_out, ) results["standard_allreduce_rmsnorm_native_compiled"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") # FlashInfer Fused AllReduce + RMSNorm Oneshot if flashinfer_comm is not None and allreduce_params is not None: try: if not disable_oneshot: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm, input_tensor, residual=residual, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, allreduce_params=allreduce_params, use_oneshot=True, ) results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms except Exception as e: logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") # FlashInfer Fused AllReduce + RMSNorm Two-shot try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm, input_tensor, residual=residual, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, allreduce_params=allreduce_params, use_oneshot=False, ) results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms except Exception as e: logger.error( "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e ) results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") if quant_mode in ["all", "fp8_only"]: # Standard AllReduce + RMSNorm + FP8 Quant try: time_ms = benchmark_operation( standard_allreduce_rmsnorm_fp8_quant, input_tensor, norm_out=norm_out, residual=residual, rms_gamma=rms_gamma, rms_eps=rms_eps, scale_factor=scale_fp8, quant_out=quant_out_fp8, ) results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled try: time_ms = benchmark_operation( standard_allreduce_rmsnorm_fp8_quant_native_compiled, input_tensor, residual=residual, rmsnorm_layer=rmsnorm_layer, # quant_fp8_layer removed in sglang version; static_quant_fp8 is used within the function scale_factor=scale_fp8, norm_out=norm_out, quant_out=quant_out_fp8, ) results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( "inf" ) # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: try: if not disable_oneshot: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp8_quant, input_tensor, norm_out=norm_out, residual=residual, rms_gamma=rms_gamma, rms_eps=rms_eps, scale_factor=scale_fp8, quant_out=quant_out_fp8, allreduce_params=allreduce_params, use_oneshot=True, ) results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( time_ms ) except Exception as e: logger.error( "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", e, ) results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( "inf" ) # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp8_quant, input_tensor, norm_out=norm_out, residual=residual, rms_gamma=rms_gamma, rms_eps=rms_eps, scale_factor=scale_fp8, quant_out=quant_out_fp8, allreduce_params=allreduce_params, use_oneshot=False, ) results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( time_ms ) except Exception as e: logger.error( "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", e, ) results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( "inf" ) if quant_mode in ["all", "fp4_only"]: # Standard AllReduce + RMSNorm + FP4 Quant try: time_ms = benchmark_operation( standard_allreduce_rmsnorm_fp4_quant, input_tensor, norm_out=norm_out, residual=residual, rms_gamma=rms_gamma, rms_eps=rms_eps, input_global_scale=scale_fp4, quant_out=fp4_quant_out, output_scale=fp4_output_scale, ) results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled try: time_ms = benchmark_operation( standard_allreduce_rmsnorm_fp4_quant_native_compiled, input_tensor, residual=residual, rmsnorm_layer=rmsnorm_layer, input_global_scale=scale_fp4, quant_out=fp4_quant_out, output_scale=fp4_output_scale, norm_out=norm_out, ) results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms except Exception as e: logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( "inf" ) # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: try: if not disable_oneshot: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp4_quant, input_tensor, residual=residual, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, input_global_scale=scale_fp4, allreduce_params=allreduce_params, quant_out=fp4_quant_out, output_scale=fp4_output_scale, use_oneshot=True, ) results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( time_ms ) except Exception as e: logger.error( "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", e, ) results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( "inf" ) # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot if flashinfer_comm is not None and allreduce_params is not None: try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp4_quant, input_tensor, residual=residual, norm_out=norm_out, rms_gamma=rms_gamma, rms_eps=rms_eps, input_global_scale=scale_fp4, allreduce_params=allreduce_params, quant_out=fp4_quant_out, output_scale=fp4_output_scale, use_oneshot=False, ) results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( time_ms ) except Exception as e: logger.error( "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", e, ) results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( "inf" ) return results def prepare_results_with_speedups(results_dict): """Prepare results with speedup calculations based on dynamic baseline selection.""" prepared_results = [] # Determine the fastest baseline for each operation type def get_fastest_baseline(op_name, results_dict): """Get the fastest baseline between standard and native_compiled versions.""" if "fp8_quant" in op_name: candidates = [ "standard_allreduce_rmsnorm_fp8_quant", "standard_allreduce_rmsnorm_fp8_quant_native_compiled", ] elif "fp4_quant" in op_name: candidates = [ "standard_allreduce_rmsnorm_fp4_quant", "standard_allreduce_rmsnorm_fp4_quant_native_compiled", ] else: candidates = [ "standard_allreduce_rmsnorm", "standard_allreduce_rmsnorm_native_compiled", ] # Find the fastest among available candidates fastest_time = float("inf") fastest_baseline = None for candidate in candidates: if ( candidate in results_dict and results_dict[candidate] != float("inf") and results_dict[candidate] < fastest_time ): fastest_time = results_dict[candidate] fastest_baseline = candidate return fastest_baseline # Create dynamic baseline mapping dynamic_baseline_mapping = {} for op_name in results_dict: if ( op_name.startswith("flashinfer_") or op_name.startswith("standard_") and not op_name.endswith("_native_compiled") ): dynamic_baseline_mapping[op_name] = get_fastest_baseline( op_name, results_dict ) for op_name, time_ms in results_dict.items(): if time_ms == float("inf"): speedup_str = "FAILED" time_str = "FAILED" else: time_str = f"{time_ms:.3f}" # Find the appropriate baseline for this operation baseline_op = dynamic_baseline_mapping.get(op_name) if baseline_op and baseline_op in results_dict: baseline_time = results_dict[baseline_op] if baseline_time != float("inf") and baseline_time > 0: speedup = baseline_time / time_ms speedup_str = f"{speedup:.2f}x" else: speedup_str = "N/A" else: # For baseline operations, determine if this is the fastest baseline if op_name.endswith("_native_compiled") or ( op_name.startswith("standard_") and not op_name.endswith("_native_compiled") ): fastest_baseline = get_fastest_baseline(op_name, results_dict) if fastest_baseline == op_name: speedup_str = "baseline" else: if fastest_baseline and fastest_baseline in results_dict: baseline_time = results_dict[fastest_baseline] if baseline_time != float("inf") and baseline_time > 0: speedup = baseline_time / time_ms speedup_str = f"{speedup:.2f}x" else: speedup_str = "N/A" else: speedup_str = "N/A" else: speedup_str = "N/A" prepared_results.append( { "operation": op_name, "time_ms": time_ms, "time_str": time_str, "speedup_str": speedup_str, } ) return prepared_results def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") print( f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " f"quant_mode={quant_mode}" ) print(f"{'=' * 80}") print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") print(f"{'-' * 80}") # Prepare results with speedup calculations prepared_results = prepare_results_with_speedups(results_dict) for result in prepared_results: if result["time_ms"] == float("inf"): time_display = result["time_str"] else: time_display = f"{result['time_ms']:.3f}" print( f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" ) def format_results_markdown( all_results: list[dict], world_size: int, args: argparse.Namespace ) -> str: """Format all benchmark results as markdown.""" markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results **World Size:** {world_size} **Hidden Dimension:** {args.hidden_dim} **Warmup Iterations:** {args.warmup} **Benchmark Trials:** {args.trials} **Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} --- """ for result in all_results: seq_len = result["seq_len"] dtype = result["dtype"] use_residual = result["use_residual"] results_dict = result["results"] residual_str = "with residual" if use_residual else "no residual" markdown += f""" ## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} | Operation | Time (ms) | Speedup | |-----------|-----------|---------| """ # Prepare results with speedup calculations prepared_results = prepare_results_with_speedups(results_dict) for result in prepared_results: # Format operation name for better readability formatted_op_name = result["operation"].replace("_", " ").title() markdown += f"| {formatted_op_name} | {result['time_str']} |" markdown += f"{result['speedup_str']} |\n" markdown += "\n" return markdown def save_results_to_file( all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int ): """Save benchmark results to markdown file (only on rank 0).""" if rank != 0: return if not all_results: logger.warning("No results to save") return output_path = args.output_file try: markdown_content = format_results_markdown(all_results, world_size, args) with open(output_path, "w") as f: f.write(markdown_content) except Exception as e: logger.error("Failed to save results to file: %s", e) def main(): parser = argparse.ArgumentParser( description="Benchmark fused collective operations" ) parser.add_argument( "--seq-lens", type=int, nargs="+", default=[128, 512, 1024, 2048], help="Sequence lengths to test", ) parser.add_argument( "--hidden-dim", type=int, default=8192, help="Hidden dimension size" ) parser.add_argument( "--dtypes", type=str, nargs="+", default=["bfloat16"], choices=["float16", "bfloat16", "float32"], help="Data types to test", ) parser.add_argument( "--no-residual", action="store_true", help="Skip residual connection tests", ) # Quantization mode options (mutually exclusive with --no-quant) quant_group = parser.add_mutually_exclusive_group() quant_group.add_argument( "--no-quant", action="store_true", help="Skip all quantization tests" ) quant_group.add_argument( "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" ) quant_group.add_argument( "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" ) quant_group.add_argument( "--quant-all", action="store_true", help="Run all quantization tests (default)", ) parser.add_argument( "--disable-oneshot", action="store_true", help="Disable oneshot mode for FlashInfer operations", ) parser.add_argument( "--warmup", type=int, default=5, help="Number of warmup iterations" ) parser.add_argument( "--trials", type=int, default=20, help="Number of benchmark trials" ) parser.add_argument( "--output-file", type=str, help="""Output file path for markdown results (default: benchmark_results_.md) """, ) args = parser.parse_args() # Check if running with torchrun (required for collective operations) if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: raise RuntimeError( "Must run with torchrun for distributed benchmarking. " "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" ) # Initialize distributed environment rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) torch.set_default_device(device) init_distributed_environment( world_size=world_size, rank=rank, local_rank=rank, backend="nccl", ) initialize_model_parallel(tensor_model_parallel_size=world_size) # Validate world size (must be > 1 for collective operations) if world_size <= 1: raise ValueError( "World size must be > 1 for collective operations benchmarking. " f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." ) # Determine quantization mode if args.no_quant: quant_mode = "none" elif args.quant_fp8: quant_mode = "fp8_only" elif args.quant_fp4: quant_mode = "fp4_only" else: # args.quant_all or default quant_mode = "all" if rank == 0: logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) logger.info("Quantization mode: %s", quant_mode) if flashinfer_comm is not None: oneshot_status = "enabled" if not args.disable_oneshot else "disabled" logger.info( "FlashInfer available - will benchmark fused operations (oneshot: %s)", oneshot_status, ) else: logger.info( "FlashInfer not available - only benchmarking standard operations" ) # Convert dtype strings to torch dtypes dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32, } dtypes = [dtype_map[dt] for dt in args.dtypes] # Test configurations residual_options = [True] if not args.no_residual else [False] if not args.no_residual: residual_options.append(False) configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) # Setup FlashInfer workspace if available ipc_handles = None allreduce_params = None if flashinfer_comm is not None: # Use the largest hidden dimension for workspace setup max_num_token = _FI_MAX_SIZES.get(world_size) // ( args.hidden_dim * world_size * 2 ) ipc_handles, workspace_tensor = setup_flashinfer_workspace( world_size, rank, args.hidden_dim, max_num_token ) if workspace_tensor is not None: allreduce_params = FlashInferFusedAllReduceParams( rank=rank, world_size=world_size, max_token_num=max_num_token, ) # Collect all results for markdown export all_results = [] try: # Run benchmarks for seq_len, dtype, use_residual in configs: if rank == 0: logger.info( "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", seq_len, args.hidden_dim, dtype, use_residual, ) results = run_benchmarks( seq_len, args.hidden_dim, dtype, use_residual, allreduce_params, quant_mode=quant_mode, disable_oneshot=args.disable_oneshot, ) # Store results for markdown export if rank == 0: all_results.append( { "seq_len": seq_len, "hidden_dim": args.hidden_dim, "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, "quant_mode": quant_mode, "results": results, } ) print_results( results, seq_len, args.hidden_dim, dtype, use_residual, quant_mode, ) # Save results to markdown file if args.output_file and rank == 0: save_results_to_file(all_results, world_size, args, rank) finally: # Cleanup if ipc_handles is not None: cleanup_flashinfer_workspace(ipc_handles) with contextlib.suppress(Exception): dist.barrier() cleanup_dist_env_and_memory(shutdown_ray=False) if __name__ == "__main__": main()