collective_fusion.py 44.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from importlib.util import find_spec
4
5
6
7

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
11
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

12
import vllm.envs as envs
13
from vllm.config import VllmConfig
14
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
15
from vllm.distributed.parallel_state import (
16
17
18
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.utils import direct_register_custom_op
22

23
from .inductor_pass import enable_fake_mode
24
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
25

26
27
# FP8 dtype of the current platform; used to build the example inputs of the
# FP8 scaled-mm patterns in this file.
FP8_DTYPE = current_platform.fp8_dtype()

# FlashInfer is optional. `flashinfer_comm` keeps the `flashinfer.comm` module
# only when it exposes `trtllm_allreduce_fusion`; in every other case
# (package missing, import failure, attribute missing) it is None and the code
# guarded by `if flashinfer_comm is not None` is skipped.
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm

        flashinfer_comm = (
            flashinfer_comm
            if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
            else None
        )
    except ImportError:
        flashinfer_comm = None
else:
    flashinfer_comm = None

logger = init_logger(__name__)

# Op overload handles used when building the fusion patterns.
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default

class BasePattern:
    """State shared by every fusion pattern defined in this file.

    Stores the dtype/device used to create example tensors for tracing, the
    tensor-parallel group coordinator and the TP world size.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        # TP group handle (its unique_name / device_group are used by patterns)
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``aten.mm`` followed by a TP reduce-scatter into one symm-mem op.

    Matched subgraph: ``aten.mm`` -> ``vllm.reduce_scatter`` (dim 0).
    Replacement: ``symm_mem.fused_matmul_reduce_scatter``.
    """

    def get_inputs(self):
        """Return example (activation, weight) tensors used to trace the pattern."""
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # The pattern identifies the group by the coordinator's unique_name;
            # the fused op is given the device group's group_name instead.
            # NOTE(review): confirm the "avg" reduce op matches the semantics of
            # vllm.reduce_scatter in the pattern (presumably a sum).
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherGEMMPattern(BasePattern):
    """Fuse a TP all-gather followed by ``aten.mm`` into one symm-mem op.

    Matched subgraph: ``vllm.all_gather`` (dim 0) -> ``aten.mm``.
    Replacement: ``symm_mem.fused_all_gather_matmul``.
    """

    def get_inputs(self):
        """Return example (local shard, weight) tensors used to trace the pattern."""
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> list[torch.Tensor]:
            # The fused op yields the gathered input and a list of matmul
            # outputs (one per weight). Only the matmul list is returned; its
            # single element stands for the pattern's one output.
            _ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse an FP8 ``_scaled_mm`` followed by a TP reduce-scatter.

    Matched subgraph: ``aten._scaled_mm`` (no bias, no result scale, output in
    the model dtype) -> ``vllm.reduce_scatter`` (dim 0).
    Replacement: ``vllm.patched_fused_scaled_matmul_reduce_scatter``.
    """

    def get_inputs(self):
        """Return example (input, weight, scale_a, scale_b) tensors for tracing."""
        a = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # Contiguous tensor transposed, i.e. a column-major weight layout.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # One scale per activation row and one per weight column.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [a, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Full (pre-scatter) shape of ``input @ mat2``.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): confirm the "avg" reduce op matches the semantics of
            # vllm.reduce_scatter in the pattern (presumably a sum).
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherScaledMMPattern(BasePattern):
    """Fuse a TP all-gather followed by an FP8 ``_scaled_mm``.

    Matched subgraph: ``vllm.all_gather`` (dim 0) -> ``aten._scaled_mm`` (no
    bias, no result scale, output in the model dtype).
    Replacement: ``symm_mem.fused_all_gather_scaled_matmul``.
    """

    def get_inputs(self):
        """Return example (shard, weight, scale_a, scale_b) tensors for tracing."""
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Contiguous tensor transposed, i.e. a column-major weight layout.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # In the pattern scale_a multiplies the *gathered* activation, so it
        # has one row per gathered row (local rows * TP size).
        s1 = x.shape[0] * self.tp_size

        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> list[torch.Tensor]:
            # NOTE(review): scale_a has the gathered row count while x is the
            # local shard; confirm this is the A_scale shape the fused op expects.
            # Only the list of matmul outputs (one per weight) is returned.
            _ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse vLLM's ``cutlass_scaled_mm`` followed by a TP reduce-scatter.

    Matched subgraph: auto-functionalized ``_C.cutlass_scaled_mm`` (no bias)
    -> ``vllm.reduce_scatter`` (dim 0).
    Replacement: ``vllm.patched_fused_scaled_matmul_reduce_scatter``.
    """

    def get_inputs(self):
        """Return example tensors for tracing, incl. the cutlass output buffer."""
        a = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # Contiguous tensor transposed, i.e. a column-major weight layout.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # One scale per activation row and one per weight column.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [a, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 is the functionalized value of the mutated `out` buffer.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # cutlass_mm_output is unused: the fused op returns its own result.
            # Full (pre-scatter) shape of ``input @ weight``.
            output_shape = [*input.shape[:-1], weight.shape[1]]
            scatter_dim = 0
            # NOTE(review): confirm the "avg" reduce op matches the semantics of
            # vllm.reduce_scatter in the pattern (presumably a sum).
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                weight,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse a TP all-gather followed by vLLM's ``cutlass_scaled_mm``.

    Matched subgraph: ``vllm.all_gather`` (dim 0) -> auto-functionalized
    ``_C.cutlass_scaled_mm`` (no bias).
    Replacement: ``symm_mem.fused_all_gather_scaled_matmul``.
    """

    def get_inputs(self):
        """Return example tensors for tracing, incl. the cutlass output buffer."""
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Contiguous tensor transposed, i.e. a column-major weight layout.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # scale_a and the output buffer are sized for the *gathered*
        # activation (local rows * TP size).
        s1 = x.shape[0] * self.tp_size

        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        s2 = weight.shape[1]
        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 is the functionalized value of the mutated `out` buffer.
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> list[torch.Tensor]:
            # `output` is unused: the fused op returns its own result.
            # NOTE(review): scale_a has the gathered row count while x is the
            # local shard; confirm this is the A_scale shape the fused op expects.
            _ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AsyncTPPass(VllmPatternMatcherPass):
    """Pattern-matcher pass that applies the async-TP (symmetric memory) fusions.

    Always registers GEMM + reduce-scatter and all-gather + GEMM. For bfloat16
    models it also registers the ``_scaled_mm`` and ``cutlass_scaled_mm``
    variants.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )
        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_shape(self, shape: int | None) -> bool:
        """Return True only for a known shape that is divisible by the TP size.

        Presumably required because the fused collectives split/gather dim 0
        evenly across TP ranks.
        """
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)


if flashinfer_comm is not None:
    # Workspace passed to FlashInfer as `workspace_ptrs`. It starts as None and
    # must be assigned (presumably by the fusion pass; not visible here) before
    # the FlashInfer path of the op below is taken.
    _FI_WORKSPACE_TENSOR = None

    MiB = 1024 * 1024
    # Max size of the input tensor per world size
    # to use flashinfer fused allreduce
    _FI_MAX_SIZES = {
        2: 64 * MiB,  # 64MB
        4: MiB,  # 1MB
        6: MiB // 2,  # 512KB
        8: MiB // 2,  # 512KB
    }

    # Optional user overrides: world size -> threshold in MiB.
    try:
        _FI_MAX_SIZES.update(
            {
                int(k): int(float(v) * MiB)
                for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
            }
        )
    except Exception as e:
        raise ValueError(
            "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e)
        ) from e

    # opt for a more conservative default value
    # when world size is not in _FI_MAX_SIZES
    _DEFAULT_FI_MAX_SIZE = MiB // 2

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        fuse_rms_quant: bool,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Allreduce ``allreduce_in`` fused with RMS norm and optional quant.

        All results are written in place into the mutated arguments
        (``allreduce_in``, ``residual``, ``norm_out``, ``quant_out``,
        ``scale_out``); nothing is returned.

        If the input (in bytes) fits both the per-world-size threshold and
        ``max_token_num`` tokens, FlashInfer's one-shot
        ``trtllm_allreduce_fusion`` is used. Otherwise the op falls back to
        ``tensor_model_parallel_all_reduce`` followed by vLLM's norm/quant
        kernels.

        ``norm_out`` selects between two modes:

        - ``None``: fused-add mode. ``allreduce_in`` doubles as the norm output
          buffer and ``residual`` is updated in place.
        - given: the norm is written to ``norm_out`` and ``allreduce_in``
          receives the allreduce result. The FlashInfer path relies on the
          caller passing a zeroed ``residual`` here.

        When ``scale_factor`` is given, the normalized output is also quantized
        into ``quant_out``. This is FP4 when ``scale_out`` is given (scales are
        written to ``scale_out``) and static FP8 otherwise.
        ``fuse_rms_quant`` makes the fallback path use vLLM's fused RMS norm +
        static FP8 quant kernels instead of separate calls.

        Raises:
            RuntimeError: if the FlashInfer path is selected but
                ``_FI_WORKSPACE_TENSOR`` has not been initialized.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        max_fusion_size = max_token_num * hidden_size * element_size
        use_flashinfer = current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )
        if use_flashinfer:
            # Explicit check instead of `assert`, which is stripped under -O.
            if _FI_WORKSPACE_TENSOR is None:
                raise RuntimeError(
                    "FlashInfer allreduce workspace (_FI_WORKSPACE_TENSOR) is not "
                    "initialized; it must be set before using the fused allreduce"
                )
            if norm_out is None:
                norm_out = allreduce_in
                residual_out = residual
            else:
                # return residual_out as allreduce_out with zeroed residual_in
                # as flashinfer does not support rms_norm
                # and allreduce_out together
                residual_out = allreduce_in
            # For the sizes that are smaller than the max size,
            # we only use flashinfer one shot allreduce
            flashinfer_comm.trtllm_allreduce_fusion(
                allreduce_in=allreduce_in,
                token_num=allreduce_in.shape[0],
                residual_in=residual,
                residual_out=residual_out,
                norm_out=norm_out,
                rms_gamma=rms_gamma,
                rms_eps=rms_eps,
                world_rank=world_rank,
                world_size=world_size,
                hidden_dim=allreduce_in.shape[-1],
                workspace_ptrs=_FI_WORKSPACE_TENSOR,
                launch_with_pdl=launch_with_pdl,
                use_oneshot=True,
                trigger_completion_at_end=trigger_completion_at_end,
                fp32_acc=fp32_acc,
                pattern_code=pattern_code,
                allreduce_out=None,
                quant_out=quant_out,
                scale_out=scale_out,
                # in vllm we only support swizzled layout
                layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
                scale_factor=scale_factor,
            )
        else:
            allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
            if scale_factor is not None and scale_out is None and fuse_rms_quant:
                # Do fused rms norm static fp8 quant fused op
                if norm_out is None:
                    torch.ops._C.fused_add_rms_norm_static_fp8_quant(
                        quant_out,
                        allreduce_out,
                        residual,
                        rms_gamma,
                        scale_factor,
                        rms_eps,
                    )
                else:
                    torch.ops._C.rms_norm_static_fp8_quant(
                        quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps
                    )
            else:
                if norm_out is None:
                    # In-place add + norm on allreduce_out; it now holds the
                    # normalized output and is used as the quant input below.
                    torch.ops._C.fused_add_rms_norm(
                        allreduce_out, residual, rms_gamma, rms_eps
                    )
                    norm_out = allreduce_out
                else:
                    torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
                if scale_factor is not None:
                    if scale_out is not None:
                        torch.ops._C.scaled_fp4_quant(
                            quant_out, norm_out, scale_out, scale_factor
                        )
                    else:
                        torch.ops._C.static_scaled_fp8_quant(
                            quant_out, norm_out, scale_factor
                        )
            if scale_factor is None or norm_out is not None:
                # Write the result back into allreduce_in (a mutated output).
                # Because norm_out was rebound to allreduce_out in the unfused
                # fused-add branch above, this runs in every fallback case
                # except the fused add + RMS norm + static FP8 quant kernel.
                allreduce_in.copy_(allreduce_out)

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        fuse_rms_quant: bool,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake impl: the op returns nothing and only mutates its arguments."""
        pass

    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
624
        fuse_rms_quant: bool = False,
625
626
627
628
629
630
631
632
633
    ):
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.use_oneshot = False
        self.max_token_num = max_token_num
634
        self.fuse_rms_quant = fuse_rms_quant
635
636
637
638
639
640
641
642
643

    def get_trtllm_fused_allreduce_kwargs(self):
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
644
            "fuse_rms_quant": self.fuse_rms_quant,
645
646
647
        }


class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Return example (input, rms_result, weight) tensors for tracing."""
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)

        return [input, rms_result, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = auto_functionalized(
                RMS_OP,
                result=rms_result,
                input=allreduce_output,
                weight=weight,
                epsilon=self.epsilon,
            )
            # rms_result, allreduce_output
            return rms[1], allreduce_output

        def replacement(
            input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
        ):
            # The residual RMSNorm kernel is reused with a zero residual, so
            # its residual output (written to allreduce_in) is the plain
            # allreduce result.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Outputs follow (op result, *mutated args in mutates_args order):
            # index 3 = norm_out (rms_result), index 1 = allreduce_in.
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Return example (residual, input, weight) tensors for tracing."""
        input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # RMSNorm gamma is 1-D over the hidden dim (matches the other patterns).
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        return [
            residual,
            input,
            weight,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            # input, residual
            return rms[1], rms[2]

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ):
            # norm_out=None selects the fused-add mode: allreduce_in receives
            # the normalized output and residual is updated in place.
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return allreduce[1], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            dtype: activation dtype of the traced example tensors.
            device: device the traced example tensors are allocated on.
            allreduce_params: supplies extra kwargs for the fused flashinfer
                op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + rms_norm + static FP8 quant -> fused op."""

        def get_inputs():
            # Example tensors for tracing; order matches pattern/replacement.
            input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
            rmsnorm_result = torch.empty(
                [1, 8, 4], device=self.device, dtype=self.dtype
            )
            quant_result = torch.empty(
                [1, 8, 4], device=self.device, dtype=self.quant_dtype
            )
            weight = torch.empty([4], device=self.device, dtype=self.dtype)
            scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
            return [input, rmsnorm_result, quant_result, weight, scale]

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(
                RMS_OP,
                result=rmsnorm_result,
                input=all_reduce,
                weight=weight,
                epsilon=self.epsilon,
            )
            quant_out_tuple = auto_functionalized(
                STATIC_FP8_QUANT_OP,
                result=quant_result,
                input=rmsnorm_out_tuple[1],
                scale=scale,
            )
            # quant_out, allreduce_output
            return quant_out_tuple[1], all_reduce

        def replacement(
            input: torch.Tensor,
            result_rms: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # The fused kernel only has residual variants; a zero residual
            # makes the residual add a no-op.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards; it is passed so that
                # (presumably) the allreduce result lands in ``allreduce_in``,
                # which is returned below — confirm against the fused op.
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return allreduce[4], allreduce[1]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            dtype: activation dtype of the traced example tensors.
            device: device the traced example tensors are allocated on.
            allreduce_params: supplies extra kwargs for the fused flashinfer
                op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + fused_add_rms_norm + static FP8 quant -> fused op."""

        def get_inputs():
            # Example tensors for tracing; order matches pattern/replacement.
            input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            quant_result = torch.empty(
                [4, 4], device=self.device, dtype=self.quant_dtype
            )
            scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

            return [
                quant_result,
                residual,
                input,
                weight,
                scale,
            ]

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)

            fused_add_rmsnorm_out_tuple = auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            quant_out_tuple = auto_functionalized(
                STATIC_FP8_QUANT_OP,
                result=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                scale=scale,
            )

            # quant_out, rms_norm_residual
            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return allreduce[4], allreduce[2]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            dtype: activation dtype of the traced example tensors.
            device: device the traced example tensors are allocated on.
            allreduce_params: supplies extra kwargs for the fused flashinfer
                op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + rms_norm + static NVFP4 quant -> fused op."""

        def get_inputs():
            # Example tensors for tracing; order matches pattern/replacement.
            # quant_result is uint8 with half the columns of input
            # (presumably two fp4 values packed per byte).
            input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
            rmsnorm_result = torch.empty(
                [1, 16, 16], device=self.device, dtype=self.dtype
            )
            quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
            input_global_scale = torch.empty(
                [1, 1], device=self.device, dtype=torch.float32
            )
            weight = torch.empty([16], device=self.device, dtype=self.dtype)
            output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

            return [
                input,
                rmsnorm_result,
                quant_result,
                weight,
                input_global_scale,
                output_scale,
            ]

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(
                RMS_OP,
                result=rmsnorm_result,
                input=all_reduce,
                weight=weight,
                epsilon=self.epsilon,
            )

            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        def replacement(
            input: torch.Tensor,
            result_rms: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            # The fused kernel only has residual variants; a zero residual
            # makes the residual add a no-op.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards; it is passed so that
                # (presumably) the allreduce result lands in ``allreduce_in``,
                # which is returned below — confirm against the fused op.
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output, output_scale
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            dtype: activation dtype of the traced example tensors.
            device: device the traced example tensors are allocated on.
            allreduce_params: supplies extra kwargs for the fused flashinfer
                op via ``get_trtllm_fused_allreduce_kwargs()``.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Register allreduce + fused_add_rms_norm + static NVFP4 quant -> fused op."""

        def get_inputs():
            # Example tensors for tracing; order matches pattern/replacement.
            input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
            input_global_scale = torch.empty(
                [1, 1], device=self.device, dtype=torch.float32
            )
            output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

            return [
                quant_result,
                residual,
                input,
                output_scale,
                weight,
                input_global_scale,
            ]

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            allreduce_output = tensor_model_parallel_all_reduce(input)

            fused_add_rmsnorm_out_tuple = auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, rms_norm_residual, output_scale
            return (
                quant_out_tuple[1],
                fused_add_rmsnorm_out_tuple[2],
                quant_out_tuple[2],
            )

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusionPass(VllmPatternMatcherPass):
    """Fuse tensor-parallel allreduce with the RMSNorm that follows it.

    Matches allreduce + RMSNorm (with or without residual add, optionally
    followed by static FP8 / NVFP4 quant) and replaces it with
    ``flashinfer_trtllm_fused_allreduce_norm``.

    The pass stays disabled (``self.disabled`` is True) unless all of these
    hold: the TP world size is > 1, a model config is present, flashinfer's
    comm module is available, and the TP world size is a key of
    ``_FI_MAX_SIZES``.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        # Only flipped to False by register_patterns() once everything is set.
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            # No tensor parallelism means there is no allreduce to fuse.
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        # Check if the world size is supported
        if self.tp_size not in _FI_MAX_SIZES:
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s",
                self.tp_size,
            )
            return
        # Token budget: the per-world-size size limit divided by the footprint
        # of one token across all ranks (4 bytes/elem for fp32 lamport, else
        # 2; presumably the limit is in bytes), capped by the user config.
        max_num_token = min(
            _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE)
            // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
            config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num,
        )
        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=max_num_token,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        # Publish the workspace at module level (presumably read by the fused
        # allreduce custom op at runtime).
        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=max_num_token,
            # fuse rms norm static fp8 quant fused op
            # in fallback path, when we don't use flashinfer
            fuse_rms_quant=config.compilation_config.pass_config.enable_fusion,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self):
        """Register every fusion pattern for each supported RMSNorm epsilon.

        Patterns are registered in a fixed order: the quant-fused variants
        first, then the plain RMSNorm ones. Keep this order; presumably it
        lets the more specific patterns match before the generic ones.
        NVFP4 patterns are only registered on device capability >= 100.
        """
        has_nvfp4 = current_platform.has_device_capability(100)
        pattern_classes = [
            AllReduceFusedRMSNormStaticQuantFP8Pattern,
            AllReduceFusedAddRMSNormStaticQuantFP8Pattern,
        ]
        if has_nvfp4:
            pattern_classes += [
                AllReduceFusedRMSNormStaticQuantNVFP4Pattern,
                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern,
            ]
        pattern_classes += [
            AllReduceRMSNormPattern,
            AllReduceFusedAddRMSNormPattern,
        ]

        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph`` (no-op when disabled)."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self):
        """Release the flashinfer IPC workspace if the pass was enabled."""
        # getattr: __init__ may not have reached the ``disabled`` assignment.
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, self.group
            )