# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

# FP8 dtype used by the current platform; the scaled-mm patterns below build
# their FP8 example inputs with it.
FP8_DTYPE = current_platform.fp8_dtype()

# `flashinfer_comm` is the `flashinfer.comm` module when flashinfer is installed,
# importable, and exposes `trtllm_allreduce_fusion`; otherwise it is None, which
# disables the `if flashinfer_comm is not None:` section of this module.
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm

        flashinfer_comm = (
            flashinfer_comm
            if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
            else None
        )
    except ImportError:
        flashinfer_comm = None
else:
    flashinfer_comm = None

logger = init_logger(__name__)

# Only bound when the compiled `_C` extension registers `scaled_fp4_quant`.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default


class BasePattern:
    """Shared state for the pattern-registration helpers in this module.

    Keeps the model dtype and device (used to build example inputs) together
    with the tensor-parallel group handle and the TP world size, which the
    subclasses embed in the collective ops they match and emit.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        # Group handle: subclasses read `unique_name` and `device_group` from it.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``aten.mm`` + ``vllm.reduce_scatter`` (dim 0, TP group) into
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self):
        # Placeholder tensors passed to pm.register_replacement as example inputs.
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # NOTE(review): "avg" is passed as the reduce op here; confirm it
            # matches the reduction performed by the matched vllm.reduce_scatter.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                # The symm_mem op is addressed by the torch device group's name,
                # whereas the matched vllm op uses the group's unique_name.
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherGEMMPattern(BasePattern):
    """Fuse ``vllm.all_gather`` (dim 0, TP group) + ``aten.mm`` into
    ``symm_mem.fused_all_gather_matmul``."""

    def get_inputs(self):
        # Placeholder tensors passed to pm.register_replacement as example inputs.
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> list[torch.Tensor]:
            # The fused op returns (gathered input, [one mm output per weight]);
            # only the matmul outputs replace the pattern's single result.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse FP8 ``aten._scaled_mm`` + ``vllm.reduce_scatter`` (dim 0) into
    ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        # Example inputs for tracing: FP8 operands with a transposed (column-major)
        # weight, a per-row scale for A and a per-column scale for B.
        a = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [a, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of the full (unscattered) product `input @ mat2`; no dimension
            # is divided by the world size here.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): "avg" is passed as the reduce op; confirm it matches
            # the reduction performed by the matched vllm.reduce_scatter.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherScaledMMPattern(BasePattern):
    """Fuse ``vllm.all_gather`` (dim 0) + FP8 ``aten._scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        # Example inputs for tracing: an FP8 shard and a transposed FP8 weight.
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # scale_a scales the *gathered* A, which has tp_size times x's rows.
        gathered_rows = x.shape[0] * self.tp_size

        scale_a = torch.empty(
            [gathered_rows, 1], device=self.device, dtype=torch.float32
        )
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> list[torch.Tensor]:
            # Returns (gathered input, [one output per weight]); the gathered
            # input is not needed by the rewritten graph.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse ``_C.cutlass_scaled_mm`` (functionalized) + ``vllm.reduce_scatter``
    (dim 0) into ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        # Example inputs for tracing: FP8 operands with a transposed weight,
        # per-row / per-column scales, and the buffer cutlass_scaled_mm writes.
        a = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [a, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            # Index 1 is the functionalized `out` tensor.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # `cutlass_mm_output` is unused: the fused op allocates its own result.
            # Shape of the full (unscattered) product `input @ mat2`; no dimension
            # is divided by the world size here.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): "avg" is passed as the reduce op; confirm it matches
            # the reduction performed by the matched vllm.reduce_scatter.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse ``vllm.all_gather`` (dim 0) + ``_C.cutlass_scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        # Example inputs for tracing: an FP8 shard and a transposed FP8 weight.
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # scale_a and the output buffer are sized for the *gathered* A,
        # which has tp_size times x's rows.
        s1 = x.shape[0] * self.tp_size

        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        s2 = weight.shape[1]
        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 is the functionalized `out` tensor.
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> list[torch.Tensor]:
            # `output` is unused: the fused op allocates its own result.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AsyncTPPass(VllmPatternMatcherPass):
    """Rewrite GEMM + reduce-scatter and all-gather + GEMM sequences on the TP
    group into PyTorch symmetric-memory fused ops (async tensor parallelism)."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )
        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

        self.dump_patterns(config, self.patterns)

    def is_applicable(self, shape: int | None) -> bool:
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        # Otherwise require a known shape that splits evenly across TP ranks.
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        # Stored on the instance so the logging/timing wrapper can report it.
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)


if flashinfer_comm is not None:
    # Passed to flashinfer as `workspace_ptrs`. It is None here;
    # NOTE(review): presumably populated elsewhere before the fused op runs —
    # the op asserts it is set before taking the flashinfer path.
    _FI_WORKSPACE_TENSOR = None

    MiB = 1024 * 1024
    # Max size (in bytes) of this rank's input tensor, keyed by world size,
    # up to which the flashinfer fused allreduce is used.
    _FI_MAX_SIZES = {
        2: 64 * MiB,  # 64MB
        4: MiB,  # 1MB
        6: MiB // 2,  # 512KB
        8: MiB // 2,  # 512KB
    }

    # User overrides: {world_size: threshold in MiB (may be fractional)}.
    try:
        _FI_MAX_SIZES.update(
            {
                int(k): int(float(v) * MiB)
                for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
            }
        )
    except Exception as e:
        raise ValueError(
            "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e)
        ) from e

    # opt for a more conservative default value
    # when world size is not in _FI_MAX_SIZES
    _DEFAULT_FI_MAX_SIZE = MiB // 2

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        fuse_rms_quant: bool,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """All-reduce ``allreduce_in`` across TP ranks, then RMS-norm it and
        optionally quantize the result.

        If the input's byte size is within both the per-world-size limit in
        ``_FI_MAX_SIZES`` and the size of ``max_token_num`` tokens, flashinfer's
        one-shot ``trtllm_allreduce_fusion`` kernel is used. Otherwise the op
        falls back to ``tensor_model_parallel_all_reduce`` followed by vLLM
        ``_C`` norm/quant kernels.

        Nothing is returned. Results are written into the arguments registered
        as mutated: ``allreduce_in``, ``residual``, ``norm_out``, ``quant_out``
        and ``scale_out``.

        - ``norm_out=None`` selects the residual-add variant, which updates
          ``residual``.
        - Otherwise ``norm_out`` receives the normalized output and
          ``allreduce_in`` receives the plain all-reduce result.
        - A non-None ``scale_factor`` requests quantization into ``quant_out``.
          On the fallback path this is FP4 when ``scale_out`` is given and
          static FP8 otherwise.

        Args:
            allreduce_in: 2D ``(num_tokens, hidden_size)`` input of this rank.
            pattern_code: flashinfer ``AllReduceFusionPattern`` selecting the
                fused epilogue (only used on the flashinfer path).
            fuse_rms_quant: on the fallback path, use the single fused
                norm+static-FP8-quant kernels instead of norm then quant.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        # Byte size of an input with max_token_num tokens; larger is not fused.
        max_fusion_size = max_token_num * hidden_size * element_size
        use_flashinfer = current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )
        if use_flashinfer:
            assert _FI_WORKSPACE_TENSOR is not None, (
                "Flashinfer must be enabled when using flashinfer"
            )
            if norm_out is None:
                # Residual-add variant: normalized output overwrites the input
                # and the updated residual overwrites `residual`.
                norm_out = allreduce_in
                residual_out = residual
            else:
                # return residual_out as allreduce_out with zeroed residual_in
                # as flashinfer does not support rms_norm
                # and allreduce_out together
                residual_out = allreduce_in
            # For the sizes that are smaller than the max size,
            # we only use flashinfer one shot allreduce
            flashinfer_comm.trtllm_allreduce_fusion(
                allreduce_in=allreduce_in,
                token_num=allreduce_in.shape[0],
                residual_in=residual,
                residual_out=residual_out,
                norm_out=norm_out,
                rms_gamma=rms_gamma,
                rms_eps=rms_eps,
                world_rank=world_rank,
                world_size=world_size,
                hidden_dim=allreduce_in.shape[-1],
                workspace_ptrs=_FI_WORKSPACE_TENSOR,
                launch_with_pdl=launch_with_pdl,
                use_oneshot=True,
                trigger_completion_at_end=trigger_completion_at_end,
                fp32_acc=fp32_acc,
                pattern_code=pattern_code,
                allreduce_out=None,
                quant_out=quant_out,
                scale_out=scale_out,
                # in vllm we only support swizzled layout
                layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
                scale_factor=scale_factor,
            )
        else:
            allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
            if scale_factor is not None and scale_out is None and fuse_rms_quant:
                # Do fused rms norm static fp8 quant fused op
                if norm_out is None:
                    torch.ops._C.fused_add_rms_norm_static_fp8_quant(
                        quant_out,
                        allreduce_out,
                        residual,
                        rms_gamma,
                        scale_factor,
                        rms_eps,
                    )
                else:
                    torch.ops._C.rms_norm_static_fp8_quant(
                        quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
                    )
            else:
                if norm_out is None:
                    torch.ops._C.fused_add_rms_norm(
                        allreduce_out, residual, rms_gamma, rms_eps
                    )
                    # No separate output buffer: the normalized result is read
                    # from allreduce_out (for quantization and the copy below).
                    norm_out = allreduce_out
                else:
                    torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
                if scale_factor is not None:
                    if scale_out is not None:
                        torch.ops._C.scaled_fp4_quant(
                            quant_out, norm_out, scale_out, scale_factor
                        )
                    else:
                        torch.ops._C.static_scaled_fp8_quant(
                            quant_out, norm_out, scale_factor
                        )
            # Only the fused-FP8 + residual-add branch skips this copy
            # (norm_out is still None there).
            if scale_factor is None or norm_out is not None:
                # we need to return allreduce output
                # in cases of non quant fused AR + RMS norm
                # and fused AR + RMS norm + quant without fused add
                allreduce_in.copy_(allreduce_out)

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        fuse_rms_quant: bool,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake implementation for tracing: the real op only mutates its
        arguments and returns None, so there is nothing to produce."""
        pass

    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
632
        fuse_rms_quant: bool = False,
633
634
635
636
637
638
639
640
641
    ):
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.use_oneshot = False
        self.max_token_num = max_token_num
642
        self.fuse_rms_quant = fuse_rms_quant
643
644
645
646
647
648
649
650
651

    def get_trtllm_fused_allreduce_kwargs(self):
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
652
            "fuse_rms_quant": self.fuse_rms_quant,
653
654
655
        }


class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self):
        input, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(input: torch.Tensor, weight: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(allreduce_output, weight)

            return rms, allreduce_output

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # The pattern has no residual, so the residual-add kernel gets zeros.
            residual = torch.zeros_like(input)
            rms_result = torch.empty_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Outputs: [1] allreduce_in, [2] residual, [3] norm_out, [4] quant_out.
            # rms_result, allreduce_in
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self):
        input, residual, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [residual, input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
            return rms, residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ):
            # norm_out=None selects the residual-add variant: the normalized
            # output is written back into allreduce_in.
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return allreduce[1], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        def first_return_only(fn):
            # Keep exactly three positional parameters: the pattern matcher
            # derives its argument names from the wrapper's signature.
            def first_output(a, b, c):
                return fn(a, b, c)[0]

            return first_output

        pm.register_replacement(
            first_return_only(pattern),
            first_return_only(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )


class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        def get_inputs():
            input, weight = self.rmsnorm_matcher.inputs()
            _, scale = self.quant_matcher.inputs()

            # input goes through allreduce first, always 16-bit
            return [input.to(self.dtype), weight, scale]

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # The pattern has no residual, so the residual-add kernel gets zeros.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=result_rms,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # Outputs: [1] allreduce_in, [2] residual, [3] norm_out, [4] quant_out.
            # quant_out, allreduce_output
            return allreduce[4], allreduce[1]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon this pattern instance is specialized for.
            dtype: Model dtype; the traced allreduce input is cast to it.
            device: Device forwarded to ``BasePattern``.
            allreduce_params: Supplies the extra kwargs passed to the fused
                flashinfer op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        # Dtype of the quantized output buffer allocated in `replacement`.
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + fused-add RMSNorm + static FP8 quant fusion.

        Args:
            pm_pass: Pattern-matcher pass the replacement is registered into.
        """

        def get_inputs():
            # Example inputs handed to pm.register_replacement for tracing.
            input, residual, weight = self.rmsnorm_matcher.inputs()
            _, scale = self.quant_matcher.inputs()

            # input goes through allreduce first, always 16-bit
            return [residual, input.to(self.dtype), weight, scale]

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant, _ = self.quant_matcher(rms, scale)

            return quant, res

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return allreduce[4], allreduce[2]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon this pattern instance is specialized for.
            dtype: Model dtype used for the traced activation/weight tensors.
            device: Device on which the example tracing tensors are created.
            allreduce_params: Supplies the extra kwargs passed to the fused
                flashinfer op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + RMSNorm + static NVFP4 quant fusion.

        The pattern quantizes through ``STATIC_FP4_QUANT_OP``; the replacement
        uses the ``kARResidualRMSNormFP4Quant`` flashinfer fusion pattern.

        Args:
            pm_pass: Pattern-matcher pass the replacement is registered into.
        """

        def get_inputs():
            # Small example tensors handed to pm.register_replacement for tracing.
            input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
            quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
            input_global_scale = torch.empty(
                [1, 1], device=self.device, dtype=torch.float32
            )
            weight = torch.empty([16], device=self.device, dtype=self.dtype)
            output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

            return [input, quant_result, weight, input_global_scale, output_scale]

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            # The fused op always takes a residual; an all-zero residual is
            # presumably what makes it match the residual-free pattern above.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output, output_scale
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon this pattern instance is specialized for.
            dtype: Model dtype used for the traced activation/weight tensors.
            device: Device on which the example tracing tensors are created.
            allreduce_params: Supplies the extra kwargs passed to the fused
                flashinfer op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + fused-add RMSNorm + static NVFP4 quant fusion.

        Args:
            pm_pass: Pattern-matcher pass the replacement is registered into.
        """

        def get_inputs():
            # Small example tensors handed to pm.register_replacement for tracing.
            input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            # NOTE(review): the residual-free NVFP4 pattern traces with a 1-D
            # [16] weight while this one uses [16, 16]; presumably the shape
            # does not affect matching - confirm before changing it.
            weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
            input_global_scale = torch.empty(
                [1, 1], device=self.device, dtype=torch.float32
            )
            output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

            return [
                quant_result,
                residual,
                input,
                output_scale,
                weight,
                input_global_scale,
            ]

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, rms_norm_residual, output_scale
            return quant_out_tuple[1], res, quant_out_tuple[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusionPass(VllmPatternMatcherPass):
    """Fuse tensor-parallel allreduce with the following RMSNorm (+ quant).

    On construction the pass allocates a flashinfer IPC workspace and
    registers the allreduce fusion patterns defined in this module. The pass
    stays disabled (``__call__`` is a no-op) when the TP world size is 1,
    when there is no model config, when ``flashinfer.comm`` is unavailable,
    or when the TP world size is not a key of ``_FI_MAX_SIZES``.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        # Only cleared by register_patterns() once every setup step succeeded.
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )

        if config.model_config is None:
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        # Check if the world size is supported
        if self.tp_size not in _FI_MAX_SIZES:
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s",
                self.tp_size,
            )
            return
        # Token budget: the per-world-size size limit divided by the per-token
        # footprint (hidden_dim * tp_size elements, 4 or 2 units each - presumably
        # bytes for fp32 vs 16-bit), capped by the configured maximum.
        max_num_token = min(
            _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE)
            // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
            config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num,
        )
        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=max_num_token,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        # Publish the workspace at module level (presumably read by the fused
        # allreduce custom op defined elsewhere in this file).
        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=max_num_token,
            # fuse rms norm static fp8 quant fused op
            # in fallback path, when we don't use flashinfer
            fuse_rms_quant=config.compilation_config.pass_config.enable_fusion,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self):
        """Register every fusion pattern for each supported RMSNorm epsilon.

        NVFP4 patterns are registered only when the device passes
        ``has_device_capability(100)`` AND ``torch.ops._C`` provides
        ``scaled_fp4_quant``: those patterns trace through
        ``STATIC_FP4_QUANT_OP``, which this module defines only when that op
        exists, so registering them without it raises ``NameError``.
        Sets ``self.disabled = False`` once registration completes.
        """
        # Registration order matches the original per-epsilon sequence.
        pattern_types = [
            AllReduceFusedRMSNormStaticQuantFP8Pattern,
            AllReduceFusedAddRMSNormStaticQuantFP8Pattern,
        ]
        if current_platform.has_device_capability(100) and hasattr(
            torch.ops._C, "scaled_fp4_quant"
        ):
            pattern_types += [
                AllReduceFusedRMSNormStaticQuantNVFP4Pattern,
                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern,
            ]
        pattern_types += [
            AllReduceRMSNormPattern,
            AllReduceFusedAddRMSNormPattern,
        ]

        for epsilon in [1e-5, 1e-6]:
            for pattern_type in pattern_types:
                pattern_type(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered fusion patterns to ``graph`` in place."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self):
        """Release the flashinfer IPC workspace if one was allocated."""
        # Key off the workspace handles rather than `disabled`: the workspace
        # is created before register_patterns() clears `disabled`, so a failure
        # during pattern registration would otherwise leak the IPC buffers.
        ipc_handles = getattr(self, "ipc_handles", None)
        if ipc_handles is None or flashinfer_comm is None:
            return
        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
            ipc_handles, self.group
        )