collective_fusion.py 44.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from importlib.util import find_spec
4
5
6
7
8
from typing import Optional

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
9
from torch._higher_order_ops.auto_functionalize import auto_functionalized
10
11
12
13
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
14
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
15
from vllm.distributed.parallel_state import (
16
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
17
from vllm.logger import init_logger
18
from vllm.platforms import current_platform
19
from vllm.utils import direct_register_custom_op
20
21
22

from .vllm_inductor_pass import VllmInductorPass

23
24
FP8_DTYPE = current_platform.fp8_dtype()

# FlashInfer is optional. `flashinfer_comm` is the `flashinfer.comm` module
# only when the package is installed, imports cleanly, and exposes
# `trtllm_allreduce_fusion`; otherwise it is None, and the
# `if flashinfer_comm is not None:` section further down is skipped.
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm
        flashinfer_comm = (flashinfer_comm if hasattr(
            flashinfer_comm, "trtllm_allreduce_fusion") else None)
    except ImportError:
        flashinfer_comm = None
else:
    flashinfer_comm = None

logger = init_logger(__name__)

# Handles to the registered collective, RMSNorm and quantization ops.
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

class BasePattern:
    """Shared state for the fusion patterns in this module.

    Holds the dtype/device that subclasses use to build example inputs for
    pattern tracing, plus the tensor-parallel group and its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        # Resolve the TP group first, then its world size.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``aten.mm`` followed by a TP reduce-scatter into one op.

    Matches an ``aten.mm`` whose result is reduce-scattered along dim 0 by
    ``vllm.reduce_scatter`` and replaces the pair with
    ``symm_mem.fused_matmul_reduce_scatter``.
    """

    def get_inputs(self):
        """Example inputs handed to ``pm.register_replacement`` for tracing.

        Created with ``torch.empty``, so their contents are uninitialized.
        """
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the mm + reduce-scatter replacement on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # The symm-mem op is addressed by the device group's name, while
            # the vLLM op in the pattern uses the coordinator's unique_name.
            # NOTE(review): the pattern reduces via vllm.reduce_scatter but
            # this call passes reduce_op="avg"; confirm the two agree.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherGEMMPattern(BasePattern):
    """Fuse a TP all-gather followed by ``aten.mm`` into one op.

    Matches ``vllm.all_gather`` along dim 0 feeding ``aten.mm`` and replaces
    the pair with ``symm_mem.fused_all_gather_matmul``.
    """

    def get_inputs(self):
        """Example inputs handed to ``pm.register_replacement`` for tracing.

        Created with ``torch.empty``, so their contents are uninitialized.
        """
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the all-gather + mm replacement on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor,
                        weight: torch.Tensor) -> list[torch.Tensor]:
            # The fused op takes a list of weights and yields a pair; only the
            # second element (matmul outputs, one per weight) is needed.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class ScaledMMReduceScatterPattern(BasePattern):
    """Rewrite FP8 ``aten._scaled_mm`` followed by a TP reduce-scatter.

    The matched subgraph multiplies an FP8 activation by an FP8 weight using
    a ``[rows, 1]`` float32 ``scale_a`` and a ``[1, cols]`` float32
    ``scale_b``, produces ``self.dtype``, and reduce-scatters the product
    along dim 0. It is replaced by a single
    ``symm_mem.fused_scaled_matmul_reduce_scatter`` call.
    """

    def get_inputs(self):
        # Example tensors for tracing (torch.empty: contents uninitialized).
        dev = self.device
        lhs = torch.empty([16, 16], device=dev, dtype=FP8_DTYPE)
        # Allocate contiguously, then transpose to obtain a column-major view.
        rhs = torch.empty([16, 16], device=dev,
                          dtype=FP8_DTYPE).contiguous().transpose(0, 1)
        row_scale = torch.empty([16, 1], device=dev, dtype=torch.float32)
        col_scale = torch.empty([1, 16], device=dev, dtype=torch.float32)
        return [lhs, rhs, row_scale, col_scale]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(input: torch.Tensor, mat2: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor) -> torch.Tensor:
            mm_out = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            return torch.ops.vllm.reduce_scatter.default(
                mm_out,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(input: torch.Tensor, mat2: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor) -> torch.Tensor:
            # The symm-mem op is addressed by the device group's name rather
            # than the coordinator's unique_name used in the pattern.
            return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim=0,
                out_dtype=self.dtype,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherScaledMMPattern(BasePattern):
    """Rewrite a TP all-gather followed by FP8 ``aten._scaled_mm``.

    The local activation shard is all-gathered along dim 0 and multiplied by
    an FP8 weight with float32 per-row / per-column scales. The pair is
    replaced by ``symm_mem.fused_all_gather_scaled_matmul``.
    """

    def get_inputs(self):
        # Example tensors for tracing (torch.empty: contents uninitialized).
        fp8 = dict(device=self.device, dtype=FP8_DTYPE)
        f32 = dict(device=self.device, dtype=torch.float32)

        x = torch.empty([8, 16], **fp8)
        weight = torch.empty([16, 16], **fp8).contiguous().transpose(0, 1)

        # scale_a applies to the *gathered* activation, which has tp_size
        # times as many rows as the local shard x.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], **f32)
        scale_b = torch.empty([1, 16], **f32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(x: torch.Tensor, weight: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor) -> torch.Tensor:
            # The fused op takes per-weight lists; with a single weight every
            # per-weight argument is a one-element list.
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul
            _, mm_outputs = fused(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Rewrite ``_C.cutlass_scaled_mm`` followed by a TP reduce-scatter.

    In the graph, the out-variant CUTLASS kernel appears wrapped in
    ``auto_functionalized``. Element 1 of its result (presumably the written
    ``out`` buffer) is reduce-scattered along dim 0. The whole subgraph is
    replaced by ``symm_mem.fused_scaled_matmul_reduce_scatter``.
    """

    def get_inputs(self):
        # Example tensors for tracing (torch.empty: contents uninitialized).
        def alloc(shape, dtype):
            return torch.empty(shape, device=self.device, dtype=dtype)

        a = alloc([16, 16], FP8_DTYPE)
        # Contiguous allocation, then transposed to a column-major view.
        b = alloc([16, 16], FP8_DTYPE).contiguous().transpose(0, 1)
        a_scales = alloc([16, 1], torch.float32)
        b_scales = alloc([1, 16], torch.float32)
        out_buf = alloc([16, 16], self.dtype)
        return [a, b, a_scales, b_scales, out_buf]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(input: torch.Tensor, weight: torch.Tensor,
                    scale_a: torch.Tensor, scale_b: torch.Tensor,
                    cutlass_mm_output: torch.Tensor) -> torch.Tensor:
            mm_result = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None)
            return torch.ops.vllm.reduce_scatter.default(
                mm_result[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

        def replacement(input: torch.Tensor, mat2: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        cutlass_mm_output: torch.Tensor) -> torch.Tensor:
            # cutlass_mm_output mirrors the pattern's signature; it is not
            # used by the fused replacement.
            return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim=0,
                out_dtype=self.dtype,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Rewrite a TP all-gather followed by ``_C.cutlass_scaled_mm``.

    The activation shard is all-gathered along dim 0 and fed to the
    out-variant CUTLASS kernel, which appears wrapped in
    ``auto_functionalized``. The pair is replaced by
    ``symm_mem.fused_all_gather_scaled_matmul``.
    """

    def get_inputs(self):
        # Example tensors for tracing (torch.empty: contents uninitialized).
        def alloc(shape, dtype):
            return torch.empty(shape, device=self.device, dtype=dtype)

        x = alloc([8, 16], FP8_DTYPE)
        weight = alloc([16, 16], FP8_DTYPE).contiguous().transpose(0, 1)

        # The gathered activation has tp_size times as many rows as x; the
        # per-row scale and the output buffer are sized to match it.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = alloc([gathered_rows, 1], torch.float32)
        scale_b = alloc([1, 16], torch.float32)
        output = alloc([gathered_rows, weight.shape[1]], self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            mm_result = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None)
            return mm_result[1]

        def replacement(x: torch.Tensor, weight: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        output: torch.Tensor) -> torch.Tensor:
            # `output` mirrors the pattern's signature but is unused here.
            # With a single weight, each per-weight argument is a 1-element
            # list.
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul
            _, mm_outputs = fused(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
class AsyncTPPass(VllmInductorPass):
    """Inductor pass that fuses TP collectives with adjacent GEMMs.

    Always registers the unquantized mm + reduce-scatter and all-gather + mm
    patterns. The FP8 ``scaled_mm`` / ``cutlass_scaled_mm`` variants are
    registered only for bfloat16 models.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        # NOTE(review): `self.model_dtype` / `self.device` are not assigned
        # in this class; presumably VllmInductorPass.__init__ sets them.
        GEMMReduceScatterPattern(self.model_dtype,
                                 self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype,
                             self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype,
                                         self.device).register(self.patterns)
            AllGatherScaledMMPattern(self.model_dtype,
                                     self.device).register(self.patterns)

            CutlassScaledMMReduceScatterPattern(
                self.model_dtype, self.device).register(self.patterns)
            AllGatherCutlassScaledMMPattern(
                self.model_dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True only for a concrete shape divisible by the TP size.

        ``None`` (no specific shape) is rejected. Presumably the divisibility
        requirement exists because the fused ops split dim 0 into equal
        per-rank chunks.
        """
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply all registered async-TP patterns to ``graph`` in place."""
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns with async TP pass.", count)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
390
391
392
393
394
395
396
397
398


if flashinfer_comm is not None:
    # Workspace handed to flashinfer as `workspace_ptrs`. It starts as None;
    # the fused op asserts it has been set before taking the FlashInfer path.
    # NOTE(review): the code that assigns it is not in this section.
    _FI_WORKSPACE_TENSOR = None

    MiB = 1024 * 1024
    # Max size (in bytes) of the input tensor, per world size,
    # for which the flashinfer fused allreduce is used.
    _FI_MAX_SIZES = {
        2: 64 * MiB,  # 64MB
        4: MiB,  # 1MB
        6: MiB // 2,  # 512KB
        8: MiB // 2,  # 512KB
    }
    # Opt for a more conservative default value
    # when world size is not in _FI_MAX_SIZES.
    _DEFAULT_FI_MAX_SIZE = MiB // 2

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        fuse_rms_quant: bool,
        norm_out: Optional[torch.Tensor] = None,
        quant_out: Optional[torch.Tensor] = None,
        scale_out: Optional[torch.Tensor] = None,
        scale_factor: Optional[torch.Tensor] = None,
    ) -> None:
        """All-reduce ``allreduce_in`` across TP ranks, then RMSNorm it.

        The norm can optionally add ``residual`` first and quantize the
        result. All results are written in place; the op is registered with
        ``allreduce_in``, ``residual``, ``norm_out``, ``quant_out`` and
        ``scale_out`` as mutated arguments.

        Two paths are used:

        - If the input fits under both the per-world-size limit and
          ``max_token_num``, a single FlashInfer one-shot fused kernel runs.
        - Otherwise it falls back to ``tensor_model_parallel_all_reduce``
          followed by separate ``torch.ops._C`` norm/quant kernels.

        Args:
            allreduce_in: 2-D ``(num_tokens, hidden_size)`` input.
            residual: Residual added before the norm. When ``norm_out`` is
                given, callers are expected to pass a zeroed residual
                because the FlashInfer path still adds it.
            rms_gamma: RMSNorm weight.
            rms_eps: RMSNorm epsilon.
            world_rank: Forwarded to
                ``flashinfer_comm.trtllm_allreduce_fusion``.
            world_size: Forwarded as above. Also selects the size limit
                from ``_FI_MAX_SIZES``.
            launch_with_pdl: Forwarded to flashinfer.
            trigger_completion_at_end: Forwarded to flashinfer.
            fp32_acc: Forwarded to flashinfer.
            max_token_num: Token cap. Combined with the per-world-size limit
                to decide whether the FlashInfer path can be used.
            pattern_code: Forwarded to flashinfer.
            fuse_rms_quant: Applies on the fallback path only, when quantizing
                to FP8 (``scale_factor`` set, ``scale_out`` None). It selects
                the fused RMSNorm + static FP8 quant kernels.
            norm_out: Destination for the normalized output. If None, the
                residual-add variant runs and ``allreduce_in`` serves as the
                norm output buffer. The exception is the fallback fused-FP8
                path, which does not update ``allreduce_in``.
            quant_out: Destination for the quantized output.
            scale_out: Destination for quantization scales. On the fallback
                path its presence selects FP4 (``scaled_fp4_quant``) over FP8.
            scale_factor: Quantization scale. If None, the fallback path
                does not quantize.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        max_fusion_size = max_token_num * hidden_size * element_size
        use_flashinfer = current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )
        if use_flashinfer:
            assert (_FI_WORKSPACE_TENSOR is not None
                    ), "Flashinfer must be enabled when using flashinfer"
            if norm_out is None:
                norm_out = allreduce_in
                residual_out = residual
            else:
                # return residual_out as allreduce_out with zeroed residual_in
                # as flashinfer does not support rms_norm
                # and allreduce_out together
                residual_out = allreduce_in
            # For the sizes that are smaller than the max size,
            # we only use flashinfer one shot allreduce
            flashinfer_comm.trtllm_allreduce_fusion(
                allreduce_in=allreduce_in,
                token_num=allreduce_in.shape[0],
                residual_in=residual,
                residual_out=residual_out,
                norm_out=norm_out,
                rms_gamma=rms_gamma,
                rms_eps=rms_eps,
                world_rank=world_rank,
                world_size=world_size,
                hidden_dim=allreduce_in.shape[-1],
                workspace_ptrs=_FI_WORKSPACE_TENSOR,
                launch_with_pdl=launch_with_pdl,
                use_oneshot=True,
                trigger_completion_at_end=trigger_completion_at_end,
                fp32_acc=fp32_acc,
                pattern_code=pattern_code,
                allreduce_out=None,
                quant_out=quant_out,
                scale_out=scale_out,
                # in vllm we only support swizzled layout
                layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
                scale_factor=scale_factor,
            )
        else:
            allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
            if (scale_factor is not None and scale_out is None
                    and fuse_rms_quant):
                # Do fused rms norm static fp8 quant fused op
                if norm_out is None:
                    torch.ops._C.fused_add_rms_norm_static_fp8_quant(
                        quant_out, allreduce_out, residual, rms_gamma,
                        scale_factor, rms_eps)
                else:
                    torch.ops._C.rms_norm_static_fp8_quant(
                        quant_out, allreduce_out, rms_gamma, scale_factor,
                        rms_eps)
            else:
                if norm_out is None:
                    torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
                                                    rms_gamma, rms_eps)
                    # Rebinding here makes the final copy below run for the
                    # residual-add path as well.
                    norm_out = allreduce_out
                else:
                    torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
                                          rms_eps)
                if scale_factor is not None:
                    if scale_out is not None:
                        torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                                                      scale_out, scale_factor)
                    else:
                        torch.ops._C.static_scaled_fp8_quant(
                            quant_out, norm_out, scale_factor)
            if scale_factor is None or norm_out is not None:
                # we need to return allreduce output
                # in cases of non quant fused AR + RMS norm
                # and fused AR + RMS norm + quant without fused add
                allreduce_in.copy_(allreduce_out)

    def call_trtllm_fused_allreduce_norm_fake(
            allreduce_in: torch.Tensor,
            residual: torch.Tensor,
            rms_gamma: torch.Tensor,
            rms_eps: float,
            world_rank: int,
            world_size: int,
            launch_with_pdl: bool,
            trigger_completion_at_end: bool,
            fp32_acc: bool,
            max_token_num: int,
            pattern_code: int,
            fuse_rms_quant: bool,
            norm_out: Optional[torch.Tensor] = None,
            quant_out: Optional[torch.Tensor] = None,
            scale_out: Optional[torch.Tensor] = None,
            scale_factor: Optional[torch.Tensor] = None) -> None:
        """Fake impl for tracing.

        The real op only mutates its arguments and returns None, so there is
        nothing to compute here.
        """
        pass

    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations.

    Bundles the per-rank settings that every fused-allreduce pattern
    forwards to ``flashinfer_trtllm_fused_allreduce_norm``.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
        fuse_rms_quant: bool = False,
    ):
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        # Fixed launch options forwarded to the fused op.
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        # NOTE(review): not forwarded by get_trtllm_fused_allreduce_kwargs;
        # call_trtllm_fused_allreduce_norm always passes use_oneshot=True.
        self.use_oneshot = False
        self.max_token_num = max_token_num
        self.fuse_rms_quant = fuse_rms_quant

    def get_trtllm_fused_allreduce_kwargs(self) -> dict:
        """Return the keyword arguments shared by every fused-op call site.

        Returns:
            A fresh dict with keys ``world_rank``, ``world_size``,
            ``launch_with_pdl``, ``trigger_completion_at_end``, ``fp32_acc``,
            ``max_token_num`` and ``fuse_rms_quant``.
        """
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
            "fuse_rms_quant": self.fuse_rms_quant,
        }


574
575
576
577
578
579
class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Example inputs handed to ``pm.register_replacement`` for tracing.

        Created with ``torch.empty``, so their contents are uninitialized.
        """
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        rms_result = torch.empty([1, 8, 4],
                                 device=self.device,
                                 dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)

        return [input, rms_result, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the allreduce + RMSNorm replacement on ``pm_pass``."""

        def pattern(input: torch.Tensor, rms_result: torch.Tensor,
                    weight: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = auto_functionalized(
                RMS_OP,
                result=rms_result,
                input=allreduce_output,
                weight=weight,
                epsilon=self.epsilon,
            )
            # rms_result, allreduce_output
            return rms[1], allreduce_output

        def replacement(input: torch.Tensor, rms_result: torch.Tensor,
                        weight: torch.Tensor):
            # The fused op is invoked with the residual-RMSNorm pattern code.
            # A zero residual makes it equivalent to plain allreduce + RMSNorm.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in (mutated args in registration order:
            # allreduce_in, residual, norm_out, ...)
            return allreduce[3], allreduce[1]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Example inputs handed to ``pm.register_replacement`` for tracing.

        Created with ``torch.empty``, so their contents are uninitialized.
        Order matches ``pattern``'s parameters: residual, input, weight.
        """
        input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # NOTE(review): this example weight is 2-D, while
        # AllReduceRMSNormPattern uses a 1-D weight. It only feeds tracing;
        # confirm the shape is intentional.
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [
            residual,
            input,
            weight,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the allreduce + fused-add RMSNorm replacement."""

        def pattern(residual: torch.Tensor, input: torch.Tensor,
                    weight: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            # input, residual
            return rms[1], rms[2]

        def replacement(residual: torch.Tensor, input: torch.Tensor,
                        weight: torch.Tensor):
            # norm_out=None selects the residual-add variant, in which the
            # normalized result is written back into allreduce_in.
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return allreduce[1], allreduce[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Replace allreduce + rms norm (without residual) + static fp8 quant
    with a single fused flashinfer call.

    Applies to the allreduce + rmsnorm + quant before attention in the
    first Transformer block, where there is no incoming residual. The
    replacement passes an all-zero residual to the fused kernel and uses
    the kARResidualRMSNormFP8Quant pattern code.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMSNorm epsilon used in both the searched and the replacement graph.
        self.epsilon = epsilon
        # flashinfer settings, expanded into kwargs of the fused op.
        self.allreduce_params = allreduce_params
        # dtype of the example quant output used when tracing.
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Trace ``pattern`` and ``replacement`` and register the rewrite
        on ``pm_pass``."""

        # Example inputs used only to trace `pattern` / `replacement`;
        # returned in the same order as those functions' parameters.
        def get_inputs():
            input = torch.zeros([1, 8, 4],
                                device=self.device,
                                dtype=self.dtype)
            rmsnorm_result = torch.empty([1, 8, 4],
                                         device=self.device,
                                         dtype=self.dtype)
            quant_result = torch.empty([1, 8, 4],
                                       device=self.device,
                                       dtype=self.quant_dtype)
            weight = torch.empty([4], device=self.device, dtype=self.dtype)
            scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
            return [input, rmsnorm_result, quant_result, weight, scale]

        # Unfused graph to search for: all-reduce -> rms_norm (written into
        # rmsnorm_result) -> static fp8 quant (written into quant_result).
        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
                                                    result=rmsnorm_result,
                                                    input=all_reduce,
                                                    weight=weight,
                                                    epsilon=self.epsilon)

            quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
                                                  result=quant_result,
                                                  input=rmsnorm_out_tuple[1],
                                                  scale=scale)

            # quant_out, allreduce_output
            return quant_out_tuple[1], all_reduce

        # Fused graph substituted for `pattern`. A zero residual is supplied,
        # presumably because the kARResidual* kernels always add a residual.
        # NOTE(review): the [4]/[1] indices assume auto_functionalized returns
        # the op result followed by the mutated args in order (allreduce_in,
        # residual, norm_out, quant_out, scale_out) -- confirm against the
        # registration of flashinfer_trtllm_fused_allreduce_norm.
        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
                        quant_result: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return allreduce[4], allreduce[1]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Replace allreduce + fused add rms norm (with residual) + static fp8
    quant with a single fused flashinfer call.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn. The rewrite keeps two outputs:
    the fp8 quant result and the updated residual.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMSNorm epsilon used in both the searched and the replacement graph.
        self.epsilon = epsilon
        # flashinfer settings, expanded into kwargs of the fused op.
        self.allreduce_params = allreduce_params
        # dtype of the example quant output used when tracing.
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Trace ``pattern`` and ``replacement`` and register the rewrite
        on ``pm_pass``."""

        # Example inputs used only to trace `pattern` / `replacement`;
        # returned in the same order as those functions' parameters.
        # NOTE(review): weight is 2-D ([4, 4]) here while the no-residual
        # FP8 pattern traces with a 1-D [4] weight -- confirm intentional.
        def get_inputs():
            input = torch.empty([4, 4], device=self.device, dtype=self.dtype)

            residual = torch.empty([4, 4],
                                   device=self.device,
                                   dtype=self.dtype)
            weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            quant_result = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.quant_dtype)
            scale = torch.empty([1, 1],
                                device=self.device,
                                dtype=torch.float32)

            return [
                quant_result,
                residual,
                input,
                weight,
                scale,
            ]

        # Unfused graph to search for: all-reduce -> fused_add_rms_norm
        # (its [1] output feeds the quant, its [2] output is returned)
        # -> static fp8 quant (written into quant_result).
        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)

            fused_add_rmsnorm_out_tuple = \
            auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(
                STATIC_FP8_QUANT_OP,
                result=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                scale=scale)

            # quant_out, updated residual ([2] of fused_add_rms_norm; pairs
            # with allreduce[2] in `replacement`)
            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]

        # Fused graph substituted for `pattern`.
        # NOTE(review): the [4]/[2] indices assume auto_functionalized returns
        # the op result followed by the mutated args in order (allreduce_in,
        # residual, norm_out, quant_out, scale_out) -- confirm against the
        # registration of flashinfer_trtllm_fused_allreduce_norm.
        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return allreduce[4], allreduce[2]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Replace allreduce + rms norm (without residual) + static nvfp4 quant
    with a single fused flashinfer call.

    Applies to the allreduce + rmsnorm + quant before attention in the
    first Transformer block, where there is no incoming residual. The
    replacement passes an all-zero residual to the fused kernel and uses
    the kARResidualRMSNormFP4Quant pattern code. Outputs kept: quantized
    values, the all-reduce result and the quantization scales.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMSNorm epsilon used in both the searched and the replacement graph.
        self.epsilon = epsilon
        # flashinfer settings, expanded into kwargs of the fused op.
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Trace ``pattern`` and ``replacement`` and register the rewrite
        on ``pm_pass``."""

        # Example inputs used only to trace `pattern` / `replacement`;
        # returned in the same order as those functions' parameters.
        # quant_result is uint8 with half the hidden width (16 -> 8),
        # presumably two fp4 values packed per byte -- confirm.
        def get_inputs():
            input = torch.empty([1, 16, 16],
                                device=self.device,
                                dtype=self.dtype)

            rmsnorm_result = torch.empty([1, 16, 16],
                                         device=self.device,
                                         dtype=self.dtype)
            quant_result = torch.empty((16, 8),
                                       device=self.device,
                                       dtype=torch.uint8)
            input_global_scale = torch.empty([1, 1],
                                             device=self.device,
                                             dtype=torch.float32)
            weight = torch.empty([16], device=self.device, dtype=self.dtype)
            output_scale = torch.empty([128, 4],
                                       device=self.device,
                                       dtype=torch.int32)

            return [
                input, rmsnorm_result, quant_result, weight,
                input_global_scale, output_scale
            ]

        # Unfused graph to search for: all-reduce -> rms_norm (written into
        # rmsnorm_result) -> fp4 quant (writes quant_result and output_scale).
        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
                                                    result=rmsnorm_result,
                                                    input=all_reduce,
                                                    weight=weight,
                                                    epsilon=self.epsilon)

            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale)

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        # Fused graph substituted for `pattern`. A zero residual is supplied,
        # presumably because the kARResidual* kernels always add a residual.
        # NOTE(review): the [4]/[1]/[5] indices assume auto_functionalized
        # returns the op result followed by the mutated args in order
        # (allreduce_in, residual, norm_out, quant_out, scale_out) -- confirm
        # against the registration of flashinfer_trtllm_fused_allreduce_norm.
        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
                        quant_result: torch.Tensor, weight: torch.Tensor,
                        input_global_scale: torch.Tensor,
                        output_scale: torch.Tensor):
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output, output_scale
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Replace allreduce + fused add rms norm (with residual) + static nvfp4
    quant with a single fused flashinfer call.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn. Outputs kept: quantized values,
    the updated residual and the quantization scales.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMSNorm epsilon used in both the searched and the replacement graph.
        self.epsilon = epsilon
        # flashinfer settings, expanded into kwargs of the fused op.
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Trace ``pattern`` and ``replacement`` and register the rewrite
        on ``pm_pass``."""

        # Example inputs used only to trace `pattern` / `replacement`;
        # returned in the same order as those functions' parameters.
        # NOTE(review): weight is 2-D ([16, 16]) here while the no-residual
        # NVFP4 pattern traces with a 1-D [16] weight -- confirm intentional.
        def get_inputs():
            input = torch.empty([16, 16], device=self.device, dtype=self.dtype)

            residual = torch.empty([16, 16],
                                   device=self.device,
                                   dtype=self.dtype)
            weight = torch.empty([16, 16],
                                 device=self.device,
                                 dtype=self.dtype)
            quant_result = torch.empty((16, 8),
                                       device=self.device,
                                       dtype=torch.uint8)
            input_global_scale = torch.empty([1, 1],
                                             device=self.device,
                                             dtype=torch.float32)
            output_scale = torch.empty([128, 4],
                                       device=self.device,
                                       dtype=torch.int32)

            return [
                quant_result,
                residual,
                input,
                output_scale,
                weight,
                input_global_scale,
            ]

        # Unfused graph to search for: all-reduce -> fused_add_rms_norm
        # (its [1] output feeds the quant, its [2] output is returned)
        # -> fp4 quant (writes quant_result and output_scale).
        def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
                    input: torch.Tensor, output_scale: torch.Tensor,
                    weight: torch.Tensor, input_global_scale: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)

            fused_add_rmsnorm_out_tuple = \
            auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale)

            # quant_out, updated residual ([2] of fused_add_rms_norm; pairs
            # with allreduce[2] in `replacement`), output_scale
            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
                2], quant_out_tuple[2]

        # Fused graph substituted for `pattern`.
        # NOTE(review): the [4]/[2]/[5] indices assume auto_functionalized
        # returns the op result followed by the mutated args in order
        # (allreduce_in, residual, norm_out, quant_out, scale_out) -- confirm
        # against the registration of flashinfer_trtllm_fused_allreduce_norm.
        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
                        input: torch.Tensor, output_scale: torch.Tensor,
                        weight: torch.Tensor,
                        input_global_scale: torch.Tensor):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)


1055
1056
class AllReduceFusionPass(VllmInductorPass):
    """
    Fuse tensor-parallel all-reduce with the RMSNorm (and, when present,
    static FP8 / NVFP4 quantization) that consumes it, using flashinfer's
    TRT-LLM all-reduce fusion kernel.

    The pass starts disabled and enables itself only after every check in
    ``__init__`` passes:

    - the TP world size is > 1;
    - a model config is available;
    - ``flashinfer.comm`` is importable;
    - the TP world size has an entry in ``_FI_MAX_SIZES``.

    Once enabled, it owns a flashinfer IPC workspace that ``__del__``
    releases.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass")
        if config.model_config is None:
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass")
            return
        # Check if the world size is supported
        if self.tp_size not in _FI_MAX_SIZES:
            logger.warning(
                "Flashinfer allreduce fusion is not "
                "supported for world size %s",
                self.tp_size,
            )
            return
        # Largest token count whose buffer fits the per-world-size budget in
        # _FI_MAX_SIZES, further capped by the user-configured limit.
        # 4 vs 2 is presumably the element size in bytes for fp32 vs 16-bit
        # Lamport buffers (assumes _FI_MAX_SIZES is in bytes -- confirm).
        max_num_token = min(
            _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
            (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
            config.compilation_config.pass_config.
            fi_allreduce_fusion_max_token_num)
        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=max_num_token,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            ))

        # Stored module-wide, presumably so the fused custom op can reach
        # the workspace without holding a reference to this pass.
        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=max_num_token,
            # fuse rms norm static fp8 quant fused op
            # in fallback path, when we don't use flashinfer
            fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)

        # Patterns are registered in this fixed order for every epsilon:
        # quant-fused variants first, then the plain all-reduce + RMSNorm
        # ones. Keep the order unless the matcher's precedence between
        # overlapping patterns has been checked.
        pattern_types = [
            AllReduceFusedRMSNormStaticQuantFP8Pattern,
            AllReduceFusedAddRMSNormStaticQuantFP8Pattern,
        ]
        # NVFP4 variants are gated on device capability 100.
        if current_platform.has_device_capability(100):
            pattern_types += [
                AllReduceFusedRMSNormStaticQuantNVFP4Pattern,
                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern,
            ]
        pattern_types += [
            AllReduceRMSNormPattern,
            AllReduceFusedAddRMSNormPattern,
        ]

        for epsilon in [1e-5, 1e-6]:
            for pattern_type in pattern_types:
                pattern_type(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def __call__(self, graph: fx.Graph):
        """Apply all registered fusion patterns to ``graph`` (no-op when
        the pass is disabled)."""
        if self.disabled:
            return
        self.begin()
        self.dump_graph(graph, "before_all_reduce_fusion_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_all_reduce_fusion_pass")
        self.end_and_log()

    def __del__(self):
        """Release the flashinfer IPC workspace created in ``__init__``."""
        # getattr: __del__ also runs on a partially constructed instance
        # (e.g. if super().__init__ raised before `disabled` was assigned);
        # treat that case as disabled instead of raising AttributeError.
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, self.group)