benchmark_fused_collective.py 39.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Benchmark for FlashInfer fused collective operations vs standard operations.

This benchmark compares:
1. FlashInfer's allreduce_fusion with trtllm backend
   (fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
   (fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations

Usage with torchrun:
    torchrun --nproc_per_node=2 benchmark_fused_collective.py

"""

import argparse
import itertools
import os
import time

import pandas as pd
import torch  # type: ignore
import torch.distributed as dist  # type: ignore

28
from vllm._custom_ops import create_fp4_output_tensors
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (
    tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
    graph_capture,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm  # noqa
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8  # noqa
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape  # noqa
from vllm.platforms import current_platform  # noqa

# Handles to vLLM's compiled custom ops (the ``torch.ops._C`` namespace).
RMS_NORM_OP = torch.ops._C.rms_norm
FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
    torch.ops._C.fused_add_rms_norm_static_fp8_quant
)
# ``.out`` overload: called with caller-provided ``output``/``output_scale``.
SCALED_FP4_QUANT_OUT_OP = torch.ops._C.scaled_fp4_quant.out

logger = init_logger(__name__)

# Try to import FlashInfer. ``flashinfer_comm`` is left as None when the
# module (or its allreduce_fusion API) is unavailable, which disables every
# FlashInfer benchmark.
TorchDistBackend = None
try:
    import flashinfer.comm as flashinfer_comm  # type: ignore
except ImportError:
    flashinfer_comm = None
    logger.warning("FlashInfer not found, only benchmarking standard operations")
else:
    # TorchDistBackend is optional: when it is missing the workspace is simply
    # created without an explicit ``comm_backend``. Import it separately so a
    # missing symbol does not disable the whole FlashInfer comm API.
    try:
        from flashinfer.comm.mnnvl import (  # type: ignore
            TorchDistBackend,
        )
    except ImportError:
        TorchDistBackend = None

    if not (
        hasattr(flashinfer_comm, "allreduce_fusion")
        and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace")
    ):
        flashinfer_comm = None
        logger.warning("FlashInfer comm module found but missing allreduce_fusion API")

# Constants
FP8_DTYPE = current_platform.fp8_dtype()
MiB = 1024 * 1024

# Maximum FlashInfer input size per supported world size. 64MB is enabled for
# world sizes 2, 4 and 8 to exercise large inputs; disable oneshot mode
# (--disable-oneshot) for very large input sizes.
_FI_MAX_SIZES = {world_size: 64 * MiB for world_size in (2, 4, 8)}

# Registry of live FlashInfer workspaces, keyed by backend name.
_FI_WORKSPACES: dict = {}

# FlashInfer backends to benchmark.
FLASHINFER_BACKENDS = ["trtllm", "mnnvl"]


def setup_flashinfer_workspace(
    backend: str,
    world_size: int,
    rank: int,
    hidden_dim: int,
    max_token_num: int,
    dtype: torch.dtype,
):
    """Setup FlashInfer workspace for fused allreduce operations.

    On success the workspace is stored in ``_FI_WORKSPACES[backend]`` so that
    ``cleanup_flashinfer_workspaces`` can destroy it later.

    Args:
        backend: Backend name passed to workspace creation (e.g. "trtllm",
            "mnnvl").
        world_size: Number of participating ranks; must be a key of
            ``_FI_MAX_SIZES``.
        rank: Rank of this process.
        hidden_dim: Hidden size passed to workspace creation.
        max_token_num: Maximum token count passed to workspace creation.
        dtype: Activation dtype passed to workspace creation.

    Returns:
        The workspace object, or None if FlashInfer is unavailable, the world
        size is unsupported, or workspace creation raised (the error is logged).
    """
    # No ``global`` statement is needed: the module-level registry is only
    # mutated here, never rebound.
    if flashinfer_comm is None:
        return None

    if world_size not in _FI_MAX_SIZES:
        logger.warning("FlashInfer not supported for world size %s", world_size)
        return None

    try:
        kwargs = {}
        if TorchDistBackend is not None:
            # Communicate over the default torch.distributed WORLD group.
            kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD)

        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
            backend=backend,
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            dtype=dtype,
            **kwargs,
        )
    except Exception as e:
        logger.error(
            "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e
        )
        return None

    _FI_WORKSPACES[backend] = workspace
    return workspace
132
133


134
135
136
def cleanup_flashinfer_workspaces():
    """Destroy every registered FlashInfer workspace and clear the registry.

    A failing ``destroy()`` is logged and does not stop the remaining
    workspaces from being destroyed. When FlashInfer is unavailable this is a
    no-op and the registry is left untouched.
    """
    if flashinfer_comm is not None:
        for name, ws in _FI_WORKSPACES.items():
            try:
                ws.destroy()
            except Exception as err:
                logger.error(
                    "Failed to cleanup FlashInfer workspace (backend=%s): %s",
                    name,
                    err,
                )
        _FI_WORKSPACES.clear()
149
150
151
152
153
154
155
156
157
158
159
160
161


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations.

    Args:
        max_token_num: Maximum number of tokens (stored for callers).
        launch_with_pdl: Forwarded to ``allreduce_fusion`` as ``launch_with_pdl``.
        fp32_acc: Forwarded to ``allreduce_fusion`` as ``fp32_acc``.
    """

    def __init__(
        self,
        max_token_num: int = 1024,
        *,
        launch_with_pdl: bool = True,
        fp32_acc: bool = True,
    ):
        self.launch_with_pdl = launch_with_pdl
        self.fp32_acc = fp32_acc
        self.max_token_num = max_token_num

    def get_flashinfer_fused_allreduce_kwargs(self):
        """Return a fresh dict of keyword arguments for ``allreduce_fusion``."""
        return {
            "launch_with_pdl": self.launch_with_pdl,
            "fp32_acc": self.fp32_acc,
        }


def flashinfer_fused_allreduce_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    allreduce_params: "FlashInferFusedAllReduceParams",
    workspace: object,
    use_oneshot: bool,
    norm_out: torch.Tensor | None = None,
):
    """Run FlashInfer's fused allreduce + residual RMSNorm (no quantization).

    Args:
        input_tensor: Per-rank input to be all-reduced.
        residual: Residual passed as ``residual_in``.
        rms_gamma: RMSNorm weight.
        rms_eps: RMSNorm epsilon.
        allreduce_params: Supplies the extra ``allreduce_fusion`` kwargs.
        workspace: Workspace from ``create_allreduce_fusion_workspace``; its
            ``backend`` attribute selects the scale-factor layout argument.
        use_oneshot: Forwarded as ``use_oneshot``.
        norm_out: Optional dedicated normalized-output buffer.

    Raises:
        RuntimeError: If FlashInfer is unavailable or ``workspace`` is None.
    """
    if workspace is None or flashinfer_comm is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    # Without a dedicated norm buffer the outputs alias the inputs
    # (norm -> input_tensor, residual -> residual); with one, input_tensor
    # is handed over as the residual output instead.
    if norm_out is not None:
        residual_out = input_tensor
    else:
        norm_out, residual_out = input_tensor, residual

    # A scale-factor layout is passed only for the trtllm backend.
    layout_code = (
        flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
        if workspace.backend == "trtllm"
        else None
    )

    fusion_kwargs = allreduce_params.get_flashinfer_fused_allreduce_kwargs()
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=None,
        scale_out=None,
        layout_code=layout_code,
        scale_factor=None,
        use_oneshot=use_oneshot,
        **fusion_kwargs,
    )


def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    scale_factor: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    workspace: object,
    use_oneshot: bool = True,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP8 quantization.

    Note: Only supported by the trtllm backend.

    Args:
        scale_factor: Quantization scale forwarded as ``scale_factor``.
        quant_out: Buffer forwarded as ``quant_out``.
        norm_out: Optional dedicated normalized-output buffer; when omitted,
            ``input_tensor`` is passed as the norm output and ``residual`` as
            the residual output.
        Remaining arguments are forwarded to ``allreduce_fusion`` unchanged.

    Raises:
        RuntimeError: If FlashInfer is unavailable or ``workspace`` is None.
    """
    if workspace is None or flashinfer_comm is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    residual_out = input_tensor if norm_out is not None else residual
    if norm_out is None:
        norm_out = input_tensor

    fusion_kwargs = allreduce_params.get_flashinfer_fused_allreduce_kwargs()
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=quant_out,
        scale_out=None,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=scale_factor,
        use_oneshot=use_oneshot,
        **fusion_kwargs,
    )


def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    input_global_scale: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    workspace: object,
    quant_out: torch.Tensor,
    use_oneshot: bool,
    output_scale: torch.Tensor,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP4 quantization.

    Note: Only supported by the trtllm backend.

    Args:
        input_global_scale: Forwarded as ``scale_factor``.
        quant_out: Buffer forwarded as ``quant_out``.
        output_scale: Buffer forwarded as ``scale_out``.
        norm_out: Optional dedicated normalized-output buffer; when omitted,
            ``input_tensor`` is passed as the norm output and ``residual`` as
            the residual output.
        Remaining arguments are forwarded to ``allreduce_fusion`` unchanged.

    Raises:
        RuntimeError: If FlashInfer is unavailable or ``workspace`` is None.
    """
    if workspace is None or flashinfer_comm is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    residual_out = input_tensor if norm_out is not None else residual
    if norm_out is None:
        norm_out = input_tensor

    fusion_kwargs = allreduce_params.get_flashinfer_fused_allreduce_kwargs()
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=workspace,
        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=quant_out,
        scale_out=output_scale,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=input_global_scale,
        use_oneshot=use_oneshot,
        **fusion_kwargs,
    )


class VllmFusedAllreduce:
    """Unfused vLLM reference path: allreduce, then separate norm/quant ops.

    Build instances inside ``set_current_vllm_config`` so the CustomOp-based
    ``RMSNorm`` / ``QuantFP8`` bind the configured forward implementation.
    """

    def __init__(self, hidden_dim, dtype):
        self.rms_eps = 1e-6
        self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype)
        # Static per-tensor FP8 quantizer; the scale is supplied per call.
        self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def allreduce_rmsnorm(
        self, input_tensor: torch.Tensor, residual: torch.Tensor | None
    ):
        """All-reduce ``input_tensor`` and return ``self.rms_norm``'s result."""
        reduced = tensor_model_parallel_all_reduce(input_tensor)
        return self.rms_norm(reduced, residual)

    def allreduce_rmsnorm_fp8_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        scale_factor: torch.Tensor,
    ):
        """All-reduce, RMSNorm, then static FP8 quant.

        Returns the quantized tensor, or ``(quantized, residual_out)`` when a
        residual is given (RMSNorm then returns a pair).
        """
        reduced = tensor_model_parallel_all_reduce(input_tensor)
        normed = self.rms_norm(reduced, residual)
        if residual is not None:
            normed, residual_out = normed
            return self.fp8_quant(normed, scale_factor), residual_out
        return self.fp8_quant(normed, scale_factor)

    def allreduce_rmsnorm_fp4_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        input_global_scale: torch.Tensor,
        quant_out: torch.Tensor,
        output_scale: torch.Tensor,
    ):
        """All-reduce, RMSNorm, then FP4 quant into the provided buffers.

        Returns ``(quant_out, output_scale)``, or
        ``(quant_out, residual_out, output_scale)`` when a residual is given.
        """
        reduced = tensor_model_parallel_all_reduce(input_tensor)
        normed = self.rms_norm(reduced, residual)
        if residual is not None:
            normed, residual_out = normed

        SCALED_FP4_QUANT_OUT_OP(
            normed,
            input_global_scale,
            True,
            output=quant_out,
            output_scale=output_scale,
        )

        if residual is not None:
            return quant_out, residual_out, output_scale
        return quant_out, output_scale


def create_test_tensors(
    num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
    """Create test tensors for benchmarking.

    Returns a 9-tuple ``(input_tensor, norm_out, residual, rms_gamma,
    scale_fp8, quant_out_fp8, scale_fp4, fp4_quant_out, fp4_output_scale)``.
    With ``use_residual`` the residual is random and ``norm_out`` is None;
    otherwise the residual is all zeros and an empty ``norm_out`` buffer of the
    input's shape is allocated. Output buffers are pre-allocated here so the
    benchmarks do not measure allocation.
    """
    x = torch.randn(num_tokens, hidden_dim, dtype=dtype)
    if use_residual:
        residual = torch.randn_like(x)
        norm_out = None
    else:
        residual = torch.zeros_like(x)
        norm_out = torch.empty_like(x)
    rms_gamma = torch.ones(hidden_dim, dtype=dtype)

    # Unit quantization scales and pre-sized quantized outputs.
    scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
    scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
    quant_out_fp8 = torch.empty_like(x, dtype=FP8_DTYPE)
    fp4_quant_out, fp4_output_scale = create_fp4_output_tensors(
        num_tokens, hidden_dim, x.device, True
    )

    return (
        x,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    )


def benchmark_operation(
    operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
    """Benchmark a single operation using CUDA graphs.

    ``operation_func(*args, **kwargs)`` is run ``warmup`` times eagerly, then
    captured 10 times into one CUDA graph (inside vLLM's ``graph_capture`` so
    ``tensor_model_parallel_all_reduce`` is graph-safe). The graph is replayed
    until at least ``trials`` operations (and at least one replay) have run.

    Returns:
        Average wall-clock time per executed operation, in milliseconds.
    """
    # Warmup before graph capture
    for _ in range(warmup):
        operation_func(*args, **kwargs)
    torch.accelerator.synchronize()

    graph = torch.cuda.CUDAGraph()
    num_op_per_cudagraph = 10

    # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
    device = torch.device(f"cuda:{torch.accelerator.current_device_index()}")
    with graph_capture(device=device), torch.cuda.graph(graph):
        for _ in range(num_op_per_cudagraph):
            operation_func(*args, **kwargs)

    # Graph warmup
    torch.accelerator.synchronize()
    for _ in range(warmup):
        graph.replay()

    # Round up so at least ``trials`` ops are timed, and divide by the number
    # of ops that actually executed (``trials`` alone is wrong whenever it is
    # not a multiple of num_op_per_cudagraph, and gave ~0 when trials < 10).
    num_replays = max(
        1, (trials + num_op_per_cudagraph - 1) // num_op_per_cudagraph
    )
    num_ops = num_replays * num_op_per_cudagraph

    torch.accelerator.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_replays):
        graph.replay()
    torch.accelerator.synchronize()
    end_time = time.perf_counter()

    return ((end_time - start_time) / num_ops) * 1000


def run_benchmarks(
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype,
    use_residual: bool,
    allreduce_params: FlashInferFusedAllReduceParams | None,
    workspaces: dict,
    quant_modes: set[str],
    no_oneshot: bool,
):
    """Run all benchmarks for given configuration.

    Each variant is timed with ``benchmark_operation``. A variant that raises
    is logged and recorded as ``float("inf")`` instead of aborting the run.

    Args:
        num_tokens: Number of rows of the benchmark input.
        hidden_dim: Number of columns of the benchmark input.
        dtype: Dtype of the input, residual and RMSNorm weight.
        use_residual: If True use a random residual; otherwise a zero residual
            and a dedicated norm output buffer.
        allreduce_params: Shared parameters for FlashInfer fused allreduce.
        workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
        quant_modes: Set of quantization modes: "none", "fp8", "fp4".
        no_oneshot: If True, only the twoshot FlashInfer variants are run.

    Returns:
        Dict mapping a result key (e.g. ``"standard_allreduce_native_rms_norm"``
        or ``"flashinfer_trtllm_fused_allreduce_rmsnorm_oneshot"``) to the
        average time in milliseconds.
    """
    (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual)

    rms_eps = 1e-6
    results = {}
    use_oneshot_options = [False] if no_oneshot else [True, False]

    if "none" in quant_modes:
        # Standard AllReduce + RMSNorm
        # Re-create VllmFusedAllreduce per config so CustomOp binds the
        # correct forward method (native vs custom kernel).
        for custom_op in ["-rms_norm", "+rms_norm"]:
            # Build the key before ``try`` so the failure path always has a
            # valid, current key even if VllmFusedAllreduce() itself raises.
            suffix = "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
            key = f"standard_allreduce{suffix}"
            with set_current_vllm_config(
                VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
            ):
                try:
                    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                    results[key] = benchmark_operation(
                        vllm_fused_allreduce.allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                    )
                except Exception as e:
                    logger.error("Standard AllReduce+RMSNorm failed: %s", e)
                    results[key] = float("inf")

        # Standard AllReduce + RMSNorm Native Compiled
        with set_current_vllm_config(
            VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_native_compiled,
                    input_tensor,
                    residual=residual,
                )
                results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
            except Exception as e:
                logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
                results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")

        # FlashInfer Fused AllReduce + RMSNorm (all backends)
        for backend, workspace in workspaces.items():
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                        norm_out=norm_out,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        allreduce_params=allreduce_params,
                        workspace=workspace,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s",
                        backend,
                        e,
                    )
                    results[key] = float("inf")

    if "fp8" in quant_modes:
        # Standard AllReduce + RMSNorm + FP8 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
                op_suffix = suffix + (
                    "_custom_quant_fp8"
                    if "+" in quant_fp8_custom_op
                    else "_native_quant_fp8"
                )
                with set_current_vllm_config(
                    VllmConfig(
                        compilation_config=CompilationConfig(
                            custom_ops=[rms_norm_custom_op, quant_fp8_custom_op]
                        )
                    )
                ):
                    try:
                        vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                        time_ms = benchmark_operation(
                            vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                            input_tensor,
                            residual=residual,
                            scale_factor=scale_fp8,
                        )
                        results[f"standard_allreduce{op_suffix}"] = time_ms
                    except Exception as e:
                        logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
                        results[f"standard_allreduce{op_suffix}"] = float("inf")

        # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
        with set_current_vllm_config(
            VllmConfig(
                compilation_config=CompilationConfig(
                    custom_ops=["-rms_norm", "-quant_fp8"]
                )
            )
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_fp8_quant_native_compiled,
                    input_tensor,
                    residual=residual,
                    scale_factor=scale_fp8,
                )
                results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = (
                    time_ms
                )
            except Exception as e:
                logger.error(
                    "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e
                )
                results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
                    "inf"
                )

        # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
        if "trtllm" in workspaces:
            trtllm_ws = workspaces["trtllm"]
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm_fp8_quant,
                        input_tensor,
                        norm_out=norm_out,
                        residual=residual,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        scale_factor=scale_fp8,
                        quant_out=quant_out_fp8,
                        allreduce_params=allreduce_params,
                        workspace=trtllm_ws,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s",
                        e,
                    )
                    results[key] = float("inf")

    # FP4 variants only run when current_platform.has_device_capability(100).
    if "fp4" in quant_modes and current_platform.has_device_capability(100):
        # Standard AllReduce + RMSNorm + FP4 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            key = f"standard_allreduce{suffix}_fp4_quant"
            with set_current_vllm_config(
                VllmConfig(
                    compilation_config=CompilationConfig(
                        custom_ops=[rms_norm_custom_op]
                    )
                )
            ):
                try:
                    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                    time_ms = benchmark_operation(
                        vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                        input_tensor,
                        residual=residual,
                        input_global_scale=scale_fp4,
                        quant_out=fp4_quant_out,
                        output_scale=fp4_output_scale,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
                    results[key] = float("inf")

        # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
        with set_current_vllm_config(
            VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
        ):
            try:
                vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
                standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_fp4_quant_native_compiled,
                    input_tensor,
                    residual=residual,
                    quant_out=fp4_quant_out,
                    input_global_scale=scale_fp4,
                    output_scale=fp4_output_scale,
                )
                results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = (
                    time_ms
                )
            except Exception as e:
                logger.error(
                    "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e
                )
                results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
                    "inf"
                )

        # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
        if "trtllm" in workspaces:
            trtllm_ws = workspaces["trtllm"]
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm_fp4_quant,
                        input_tensor,
                        residual=residual,
                        norm_out=norm_out,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        input_global_scale=scale_fp4,
                        allreduce_params=allreduce_params,
                        workspace=trtllm_ws,
                        quant_out=fp4_quant_out,
                        output_scale=fp4_output_scale,
                        use_oneshot=use_oneshot,
                    )
                    results[key] = time_ms
                except Exception as e:
                    logger.error(
                        "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s",
                        e,
                    )
                    results[key] = float("inf")

    return results


def prepare_results_with_speedups(results_dict):
    """Prepare results with speedup calculations based on dynamic baseline selection.

    The baseline for an operation is the fastest successful ``standard_*``
    result in the same quantization category (fp8, fp4, or none). This covers
    the eager native/custom-op variants as well as ``*_native_compiled``.

    Args:
        results_dict: Mapping of operation name to time in ms; ``float("inf")``
            marks a failed run.

    Returns:
        A list, in ``results_dict`` order, of dicts with keys ``operation``,
        ``time_ms``, ``time_str`` and ``speedup_str``. ``speedup_str`` is
        ``"FAILED"`` for failed runs, ``"baseline"`` for the baseline itself,
        ``"<baseline_time / time>x"`` for other ``standard_*``/``flashinfer_*``
        operations, and ``"N/A"`` when no baseline applies.
    """
    inf = float("inf")

    def _quant_category(op_name):
        # Keys spell the quant mode inconsistently ("fp8_quant" vs
        # "quant_fp8"), so classify on the dtype tag alone.
        if "fp8" in op_name:
            return "fp8"
        if "fp4" in op_name:
            return "fp4"
        return "none"

    def get_fastest_baseline(op_name, results_dict):
        """Return the fastest successful standard_* op in op_name's category."""
        category = _quant_category(op_name)
        fastest_time = inf
        fastest_baseline = None
        for candidate, candidate_time in results_dict.items():
            if (
                candidate.startswith("standard_")
                and _quant_category(candidate) == category
                and candidate_time < fastest_time
            ):
                fastest_time = candidate_time
                fastest_baseline = candidate
        return fastest_baseline

    def _speedup_str(op_name, time_ms):
        if not op_name.startswith(("standard_", "flashinfer_")):
            return "N/A"
        baseline_op = get_fastest_baseline(op_name, results_dict)
        if baseline_op is None:
            return "N/A"
        if baseline_op == op_name:
            return "baseline"
        baseline_time = results_dict[baseline_op]
        # Guard against zero/negative timings to avoid ZeroDivisionError.
        if not (baseline_time > 0 and time_ms > 0):
            return "N/A"
        return f"{baseline_time / time_ms:.2f}x"

    prepared_results = []
    for op_name, time_ms in results_dict.items():
        if time_ms == inf:
            time_str = "FAILED"
            speedup_str = "FAILED"
        else:
            time_str = f"{time_ms:.3f}"
            speedup_str = _speedup_str(op_name, time_ms)

        prepared_results.append(
            {
                "operation": op_name,
                "time_ms": time_ms,
                "time_str": time_str,
                "speedup_str": speedup_str,
            }
        )

    return prepared_results


def print_results(
    results_dict,
    num_tokens,
    hidden_dim,
    dtype,
    use_residual,
    quant_modes,
    input_size_mb,
):
    """Print one configuration's benchmark timings as a fixed-width text table.

    Args:
        results_dict: Mapping of operation name -> time in ms, where
            ``float("inf")`` marks a failed operation. It is passed through
            ``prepare_results_with_speedups`` to obtain speedup strings.
        num_tokens: Token count echoed in the header.
        hidden_dim: Hidden size echoed in the header.
        dtype: Data type echoed in the header.
        use_residual: Whether the residual path was used ("yes"/"no").
        quant_modes: Enabled quantization modes, printed sorted and comma-joined.
        input_size_mb: Input size in MB, printed with two decimals.
    """
    rule = "=" * 80
    residual_flag = "yes" if use_residual else "no"
    modes_csv = ",".join(sorted(quant_modes))

    header_lines = (
        "\n" + rule,
        f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} "
        f"(input size: {input_size_mb:.2f} MB)",
        f"dtype={dtype}, residual={residual_flag}, quant_modes={modes_csv}",
        rule,
        f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}",
        "-" * 80,
    )
    for text in header_lines:
        print(text)

    for row in prepare_results_with_speedups(results_dict):
        elapsed = row["time_ms"]
        # Failed runs (inf) keep their preformatted marker; others get 3 decimals.
        shown = row["time_str"] if elapsed == float("inf") else f"{elapsed:.3f}"
        print(f"{row['operation']:<50} {shown:<12} {row['speedup_str']:<10}")


def _rows_to_markdown_table(rows: list[dict[str, str]]) -> str:
    """Render ``rows`` (dicts sharing the same keys) as a pipe-style markdown table.

    Fallback for when ``DataFrame.to_markdown`` cannot run because its
    optional ``tabulate`` dependency is not installed.
    """
    headers = list(rows[0])
    table = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(" --- " for _ in headers) + "|",
    ]
    table.extend(
        "| " + " | ".join(str(row[h]) for h in headers) + " |" for row in rows
    )
    return "\n".join(table)


def format_results_markdown(
    all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
    """Format all benchmark results as a markdown document.

    Args:
        all_results: One dict per benchmarked configuration, with keys
            ``num_tokens``, ``dtype``, ``use_residual``, ``input_size_mb``,
            ``quant_modes`` and ``results`` (operation name -> time in ms).
            The quantization modes shown in the header come from the first
            entry.
        world_size: Number of ranks, shown in the header.
        args: Parsed CLI namespace. Reads ``hidden_dim``, ``warmup`` and
            ``trials``.

    Returns:
        The markdown text, ending with a newline.
    """
    modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A"
    lines: list[str] = [
        "# FlashInfer Fused Collective Operations Benchmark Results",
        "",
        f"**World Size:** {world_size}  ",
        f"**Hidden Dimension:** {args.hidden_dim}  ",
        f"**Warmup Iterations:** {args.warmup}  ",
        f"**Benchmark Trials:** {args.trials}  ",
        f"**Quantization Modes:** {modes}  ",
        "",
        "---",
        "",
    ]

    for entry in all_results:
        residual_str = "with residual" if entry["use_residual"] else "no residual"
        lines.append(
            f"## Configuration: num_tokens={entry['num_tokens']}, "
            f"dtype={entry['dtype']}, {residual_str}"
        )
        lines.append(f"**Input Size:** {entry['input_size_mb']:.2f} MB")
        lines.append("")

        prepared = prepare_results_with_speedups(entry["results"])
        rows = [
            {
                "Operation": r["operation"].replace("_", " ").title(),
                "Time (ms)": r["time_str"],
                "Speedup": r["speedup_str"],
            }
            for r in prepared
        ]
        if not rows:
            lines.append("No results.")
        else:
            try:
                lines.append(pd.DataFrame(rows).to_markdown(index=False))
            except ImportError:
                # DataFrame.to_markdown requires the optional `tabulate`
                # package; render a plain table rather than losing results.
                lines.append(_rows_to_markdown_table(rows))
        lines.append("")

    return "\n".join(lines)


def save_results_to_file(
    all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
    """Append the markdown rendering of ``all_results`` to ``args.output_file``.

    Only rank 0 writes; every other rank returns immediately. The file is
    opened in append mode, so repeated runs accumulate in the same file.
    Export is best-effort: a missing path or any rendering/I/O error is
    logged (with traceback) instead of raised, so a completed benchmark run
    is not aborted by a reporting failure.

    Args:
        all_results: Per-configuration result dicts, as accepted by
            ``format_results_markdown``.
        world_size: Number of ranks, forwarded to the formatter.
        args: Parsed CLI namespace. Reads ``output_file`` plus the fields
            the formatter needs.
        rank: Global rank of the calling process.
    """
    if rank != 0:
        return

    if not all_results:
        logger.warning("No results to save")
        return

    output_path = args.output_file
    if not output_path:
        # Previously this surfaced as a TypeError from open(None) inside the
        # broad except below; report it explicitly instead.
        logger.warning("No output file specified; skipping markdown export")
        return

    try:
        markdown_content = format_results_markdown(all_results, world_size, args)
        with open(output_path, "a", encoding="utf-8") as f:
            f.write(markdown_content)
    except Exception:
        # Best-effort export: keep the traceback in the log but don't crash.
        logger.exception("Failed to save results to file %s", output_path)
        return

    logger.info("Benchmark results appended to %s", output_path)


def _parse_quant_modes(raw: str | None) -> set[str]:
    """Parse the comma-separated ``--quant-modes`` value into a validated set.

    Entries are stripped and lower-cased; blank entries are ignored. An empty
    or missing value selects every supported mode.

    Raises:
        ValueError: if any entry is not one of ``none``, ``fp8``, ``fp4``.
    """
    valid_quant_modes = {"none", "fp8", "fp4"}
    requested = {m.strip().lower() for m in (raw or "").split(",") if m.strip()}
    quant_modes = requested or set(valid_quant_modes)
    invalid = sorted(quant_modes - valid_quant_modes)
    if invalid:
        raise ValueError(
            f"Invalid --quant-modes entries: {','.join(invalid)}. "
            f"Valid options are: {','.join(sorted(valid_quant_modes))}."
        )
    return quant_modes


def main():
    """Run the fused-collective benchmark sweep under ``torchrun``.

    Steps:
      1. Parse CLI arguments and validate the launch environment. ``RANK``
         and ``WORLD_SIZE`` must be set and the world size must exceed 1.
      2. Bind this process to its local CUDA device.
      3. Initialize vLLM's distributed and model-parallel state.
      4. If FlashInfer is available, set up its workspaces.
      5. Call ``run_benchmarks`` for every (num_tokens, dtype, residual)
         combination.

    Rank 0 prints one table per configuration. When ``--output-file`` is
    given, rank 0 also appends markdown results to that file. FlashInfer
    workspaces are cleaned up and all ranks hit a barrier after the
    benchmark loop, including when the loop raises.

    Raises:
        RuntimeError: if not launched via torchrun.
        ValueError: if the world size is <= 1 or ``--quant-modes`` contains an
            unknown entry.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark fused collective operations"
    )
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[128, 512, 1024, 2048],
        help="Numbers of tokens to test",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=8192, help="Hidden dimension size"
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["bfloat16"],
        choices=["float16", "bfloat16", "float32"],
        help="Data types to test",
    )
    parser.add_argument(
        "--no-residual",
        action="store_true",
        help="Skip residual connection tests",
    )
    parser.add_argument(
        "--quant-modes",
        type=str,
        default="none,fp8,fp4",
        help=(
            "Comma-separated quantization modes to run: none, fp8, fp4. "
            "Default: none,fp8,fp4"
        ),
    )
    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--trials", type=int, default=20, help="Number of benchmark trials"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        # There is no default file: results are only written when this is set.
        help=(
            "Output file path for markdown results (appended to if it "
            "already exists). If omitted, results are only printed."
        ),
    )
    parser.add_argument(
        "--no-oneshot",
        action="store_true",
        help="Skip oneshot benchmarks",
    )

    args = parser.parse_args()

    # Check if running with torchrun (required for collective operations)
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError(
            "Must run with torchrun for distributed benchmarking. "
            "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
        )

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # torchrun's LOCAL_RANK is the per-node GPU index. RANK is global and
    # names a non-existent device on the second and later nodes of a
    # multi-node run. Fall back to RANK when LOCAL_RANK is absent.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    # Validate inputs before creating any device/distributed state.
    if world_size <= 1:
        raise ValueError(
            "World size must be > 1 for collective operations benchmarking. "
            f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
        )
    quant_modes = _parse_quant_modes(args.quant_modes)

    device = torch.device(f"cuda:{local_rank}")
    torch.accelerator.set_device_index(device)
    torch.set_default_device(device)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    if rank == 0:
        logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
        logger.info("Quantization modes: %s", ",".join(sorted(quant_modes)))
        if flashinfer_comm is not None:
            logger.info(
                "FlashInfer available - will benchmark fused operations",
            )
        else:
            logger.info(
                "FlashInfer not available - only benchmarking standard operations"
            )

    # Convert dtype strings to torch dtypes
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtypes = [dtype_map[dt] for dt in args.dtypes]

    # Test configurations
    residual_options = [False] if args.no_residual else [True]
    configs = list(itertools.product(args.num_tokens, dtypes, residual_options))

    # Setup FlashInfer workspaces for all backends. They stay unset (same as
    # the "FlashInfer not available" path) if sizing information is missing.
    allreduce_params = None
    if flashinfer_comm is not None:
        # Size the shared workspace for the widest dtype being benchmarked.
        max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes)
        workspace_dtype = (
            torch.float32
            if max_element_size == 4
            else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16)
        )
        # Presumably a per-world-size byte budget (it is divided by
        # hidden_dim * bytes-per-element below). Missing entries used to crash
        # with a TypeError on `None // int`.
        max_size = _FI_MAX_SIZES.get(world_size)
        if max_size is None:
            logger.warning(
                "No FlashInfer fusion size configured for world_size=%s; "
                "skipping FlashInfer workspace setup",
                world_size,
            )
        else:
            max_num_token = max_size // (args.hidden_dim * max_element_size)
            for backend in FLASHINFER_BACKENDS:
                setup_flashinfer_workspace(
                    backend=backend,
                    world_size=world_size,
                    rank=rank,
                    hidden_dim=args.hidden_dim,
                    max_token_num=max_num_token,
                    dtype=workspace_dtype,
                )

            if _FI_WORKSPACES:
                allreduce_params = FlashInferFusedAllReduceParams(
                    max_token_num=max_num_token,
                )

    # Collect all results for markdown export (populated on rank 0 only)
    all_results = []

    try:
        for num_tokens, dtype, use_residual in configs:
            if rank == 0:
                logger.info(
                    "\nTesting:  num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s",
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                )

            results = run_benchmarks(
                num_tokens,
                args.hidden_dim,
                dtype,
                use_residual,
                allreduce_params,
                workspaces=_FI_WORKSPACES,
                quant_modes=quant_modes,
                no_oneshot=args.no_oneshot,
            )

            if rank == 0:
                input_size_mb = (
                    num_tokens * args.hidden_dim * torch.finfo(dtype).bits
                ) / (8 * 1024 * 1024)
                all_results.append(
                    {
                        "num_tokens": num_tokens,
                        "hidden_dim": args.hidden_dim,
                        "dtype": str(dtype).replace("torch.", ""),
                        "use_residual": use_residual,
                        "quant_modes": sorted(quant_modes),
                        "input_size_mb": input_size_mb,
                        "results": results,
                    }
                )

                print_results(
                    results,
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                    quant_modes,
                    input_size_mb,
                )

        # Save results to markdown file
        if args.output_file and rank == 0:
            save_results_to_file(all_results, world_size, args, rank)

    finally:
        cleanup_flashinfer_workspaces()
        dist.barrier()


if __name__ == "__main__":
    # Run main() inside a default vLLM config context, presumably because
    # vLLM layers/ops used by the benchmark read the current config.
    # VllmConfig and set_current_vllm_config are imported at module top.
    with set_current_vllm_config(VllmConfig()):
        main()