# collective_fusion.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from importlib.util import find_spec
4
5
6
7

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
8
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
10
11
12
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
13
from vllm.config.utils import Range
14
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
15
from vllm.distributed.parallel_state import (
16
17
18
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
19
from vllm.logger import init_logger
20
21
22
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
23
from vllm.platforms import current_platform
24
from vllm.utils.torch_utils import direct_register_custom_op
25

26
from .inductor_pass import enable_fake_mode
27
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
28
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
29

30
31
# FP8 dtype reported by the current platform; used for the example inputs of
# the scaled-matmul patterns in this module.
FP8_DTYPE = current_platform.fp8_dtype()

32
# Optional FlashInfer communication module. ``flashinfer_comm`` stays ``None``
# unless the package is installed, importable, and exposes
# ``trtllm_allreduce_fusion``.
flashinfer_comm = None
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm
    except ImportError:
        flashinfer_comm = None
    else:
        if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
            flashinfer_comm = None

46
47
# Module-level logger.
logger = init_logger(__name__)

48
49
# Only defined when the compiled `_C` extension provides `scaled_fp4_quant`;
# otherwise the name STATIC_FP4_QUANT_OP does not exist in this module.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

class BasePattern:
    """Shared state for the fusion patterns defined in this module.

    Keeps the dtype and device that subclasses use for their example inputs,
    plus the tensor-parallel group and its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype, self.device = dtype, device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten.mm`` + ``vllm.reduce_scatter`` (dim 0) into
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Return the example ``[lhs, rhs]`` tensors used for registration."""

        def _empty(rows: int, cols: int) -> torch.Tensor:
            return torch.empty([rows, cols], device=self.device, dtype=self.dtype)

        return [_empty(16, 4), _empty(4, 4)]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            # Plain matmul whose result is reduce-scattered along dim 0.
            return torch.ops.vllm.reduce_scatter.default(
                torch.ops.aten.mm.default(mul, mm_weight),
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            # Single fused op, addressed by the device group's name.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108


class AllGatherGEMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` (dim 0) + ``aten.mm`` into
    ``symm_mem.fused_all_gather_matmul``."""

    def get_inputs(self):
        """Return the example ``[x, weight]`` tensors used for registration."""
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(
            x: torch.Tensor, weight: torch.Tensor
        ) -> list[torch.Tensor]:
            # The fused op yields the gathered input and the matmul outputs
            # (one per weight passed in); only the matmul outputs stand in
            # for the pattern's result, so the gathered input is discarded.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
128
129


130
131
132
class ScaledMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten._scaled_mm`` + ``vllm.reduce_scatter`` (dim 0) into
    ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Return example ``[input, mm_weight, scale_a, scale_b]`` tensors."""
        fp8_opts = {"device": self.device, "dtype": FP8_DTYPE}
        f32_opts = {"device": self.device, "dtype": torch.float32}

        lhs = torch.empty([16, 16], **fp8_opts)
        # Weight is a transposed view of a contiguous tensor.
        rhs = torch.empty([16, 16], **fp8_opts).contiguous().transpose(0, 1)
        lhs_scale = torch.empty([16, 1], **f32_opts)
        rhs_scale = torch.empty([1, 16], **f32_opts)
        return [lhs, rhs, lhs_scale, rhs_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            return torch.ops.vllm.reduce_scatter.default(
                torch.ops.aten._scaled_mm.default(
                    input,
                    mat2=mat2,
                    scale_a=scale_a,
                    scale_b=scale_b,
                    bias=None,
                    scale_result=None,
                    out_dtype=self.dtype,
                ),
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of ``input @ mat2``: leading dims of input, then the
            # column count of mat2.
            mm_shape = [*input.shape[:-1], mat2.shape[1]]
            dim0 = 0
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                dim0,  # orig_scatter_dim
                dim0,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                mm_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
196
197
198
199
200


class AllGatherScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` (dim 0) + ``aten._scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        """Return example ``[x, weight, scale_a, scale_b]`` tensors."""
        fp8_opts = {"device": self.device, "dtype": FP8_DTYPE}
        f32_opts = {"device": self.device, "dtype": torch.float32}

        local_x = torch.empty([8, 16], **fp8_opts)
        # Weight is a transposed view of a contiguous tensor.
        w = torch.empty([16, 16], **fp8_opts).contiguous().transpose(0, 1)

        # scale_a has one row per row of the gathered input
        # (local rows * TP world size).
        gathered_rows = local_x.shape[0] * self.tp_size
        row_scale = torch.empty([gathered_rows, 1], **f32_opts)
        col_scale = torch.empty([1, 16], **f32_opts)

        return [local_x, w, row_scale, col_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            return torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # First element is the gathered input (unused here); the second
            # holds the matmul outputs.
            _gathered, mm_outputs = fused
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
258
259
260
261
262


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Rewrite functionalized ``_C.cutlass_scaled_mm`` +
    ``vllm.reduce_scatter`` (dim 0) into
    ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self):
        """Return example ``[input, mm_weight, scale_a, scale_b, out]``."""
        fp8_opts = {"device": self.device, "dtype": FP8_DTYPE}
        f32_opts = {"device": self.device, "dtype": torch.float32}

        lhs = torch.empty([16, 16], **fp8_opts)
        # Weight is a transposed view of a contiguous tensor.
        rhs = torch.empty([16, 16], **fp8_opts).contiguous().transpose(0, 1)
        lhs_scale = torch.empty([16, 1], **f32_opts)
        rhs_scale = torch.empty([1, 16], **f32_opts)

        # Output buffer that cutlass_scaled_mm writes into.
        out_buf = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [lhs, rhs, lhs_scale, rhs_scale, out_buf]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            functionalized_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            # Index 1 of the functionalized result is fed to reduce_scatter.
            return torch.ops.vllm.reduce_scatter.default(
                functionalized_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of ``input @ mat2``: leading dims of input, then the
            # column count of mat2.
            mm_shape = [*input.shape[:-1], mat2.shape[1]]
            dim0 = 0
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                dim0,  # orig_scatter_dim
                dim0,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                mm_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
331
332
333
334
335


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` (dim 0) + functionalized
    ``_C.cutlass_scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self):
        """Return example ``[x, weight, scale_a, scale_b, output]`` tensors."""
        fp8_opts = {"device": self.device, "dtype": FP8_DTYPE}
        f32_opts = {"device": self.device, "dtype": torch.float32}

        local_x = torch.empty([8, 16], **fp8_opts)
        # Weight is a transposed view of a contiguous tensor.
        w = torch.empty([16, 16], **fp8_opts).contiguous().transpose(0, 1)

        # Row count after gathering along dim 0 across the TP group.
        gathered_rows = local_x.shape[0] * self.tp_size
        row_scale = torch.empty([gathered_rows, 1], **f32_opts)
        col_scale = torch.empty([1, 16], **f32_opts)

        # Output buffer sized for (gathered rows, weight columns).
        out_cols = w.shape[1]
        out_buf = torch.empty(
            [gathered_rows, out_cols], device=self.device, dtype=self.dtype
        )

        return [local_x, w, row_scale, col_scale, out_buf]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            functionalized_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            return functionalized_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            fused = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            # First element is the gathered input (unused here); the second
            # holds the matmul outputs.
            _gathered, mm_outputs = fused
            return mm_outputs

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
399
400


401
class AsyncTPPass(VllmPatternMatcherPass):
    """Pattern-matcher pass that swaps collective + GEMM sequences for the
    fused symmetric-memory ops registered by the pattern classes above."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        tp_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )

        pattern_classes = [GEMMReduceScatterPattern, AllGatherGEMMPattern]

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            pattern_classes += [
                ScaledMMReduceScatterPattern,
                AllGatherScaledMMPattern,
                CutlassScaledMMReduceScatterPattern,
                AllGatherCutlassScaledMMPattern,
            ]

        # Instantiate and register one pattern at a time, in list order.
        for pattern_cls in pattern_classes:
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether this pass should run for ``compile_range``."""
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        cfg = self.compilation_config
        if not cfg.splitting_ops or cfg.use_inductor_graph_partition:
            return True

        # Otherwise require a single-size range whose end is divisible by
        # the TP world size.
        tp_size = get_tensor_model_parallel_world_size()
        return compile_range.is_single_size() and compile_range.end % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
451
452


453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
# Max size (in MB) of the input tensor for which the flashinfer fused
# allreduce is used.
# Outer key: device capability as an int; inner key: world size.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
    90: {
        2: 64,  # 64MB
        4: 2,  # 2MB
        8: 0.5,  # 0.5MB
    },
    100: {
        2: 64,  # 64MB
        4: 32,  # 32MB
        8: 1,  # 1MB
    },
}

# Max size (in MB) of the input tensor for which the flashinfer one shot
# fused allreduce is used.
# Outer key: device capability as an int; inner key: world size.
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
# When a (device capability, world size) pair has no entry here,
# call_trtllm_fused_allreduce_norm uses one shot regardless of size.
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
    90: {
        2: 32,  # 32MB
        4: 2,  # 2MB
        8: 0.5,  # 0.5MB
    },
    100: {
        2: 32,  # 32MB
        4: 4,  # 4MB
        8: 1,  # 1MB
    },
}


485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
if flashinfer_comm is not None:
    # Workspace handle read by call_trtllm_fused_allreduce_norm; it starts
    # as None and must be assigned before the op runs.
    _FI_WORKSPACE_TENSOR = None
    MiB = 1024 * 1024

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Run ``flashinfer_comm.trtllm_allreduce_fusion`` on ``allreduce_in``.

        ``allreduce_in`` must be 2-D (tokens, hidden). Results are written
        into the tensors passed in; nothing is returned.

        Raises:
            AssertionError: if the input is larger than ``max_token_num``
                tokens' worth of bytes, or if the workspace is unset.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size

        size_limit = max_token_num * hidden_size * element_size
        assert current_tensor_size <= size_limit, (
            f"Current tensor size {current_tensor_size} is larger than "
            f"max token num {max_token_num} * hidden size {hidden_size} * "
            f"element size {element_size}"
        )

        # Look up the one shot size limit (in MB) for this device capability
        # and world size.
        device_capability = current_platform.get_device_capability().to_int()
        per_world_size_limits = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
            device_capability, {}
        )
        one_shot_limit_mb = per_world_size_limits.get(world_size, None)
        # Use one shot if no max size is specified
        if one_shot_limit_mb is None:
            use_oneshot = True
        else:
            use_oneshot = current_tensor_size <= one_shot_limit_mb * MiB

        assert _FI_WORKSPACE_TENSOR is not None, (
            "Flashinfer must be enabled when using flashinfer"
        )

        if norm_out is not None:
            # return residual_out as allreduce_out with zeroed residual_in
            # as flashinfer does not support rms_norm
            # and allreduce_out together
            residual_out = allreduce_in
        else:
            # No separate norm buffer: the norm result goes into
            # allreduce_in and the residual is updated in place.
            norm_out, residual_out = allreduce_in, residual

        flashinfer_comm.trtllm_allreduce_fusion(
            allreduce_in=allreduce_in,
            token_num=num_tokens,
            residual_in=residual,
            residual_out=residual_out,
            norm_out=norm_out,
            rms_gamma=rms_gamma,
            rms_eps=rms_eps,
            world_rank=world_rank,
            world_size=world_size,
            hidden_dim=hidden_size,
            workspace_ptrs=_FI_WORKSPACE_TENSOR,
            launch_with_pdl=launch_with_pdl,
            use_oneshot=use_oneshot,
            trigger_completion_at_end=trigger_completion_at_end,
            fp32_acc=fp32_acc,
            pattern_code=pattern_code,
            allreduce_out=None,
            quant_out=quant_out,
            scale_out=scale_out,
            # in vllm we only support swizzled layout
            layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
            scale_factor=scale_factor,
        )

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake implementation of the op: same signature, does nothing."""
        return None

    # Register the real and fake implementations as a custom op, declaring
    # which arguments it mutates.
    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    # Handle to the registered op's default overload, used by the patterns.
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
    ):
        # Caller-supplied settings.
        self.rank, self.world_size = rank, world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.max_token_num = max_token_num
        # Fixed launch options, always enabled.
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True

    def get_trtllm_fused_allreduce_kwargs(self):
        """Return the keyword arguments shared by every fused allreduce call."""
        return dict(
            world_rank=self.rank,
            world_size=self.world_size,
            launch_with_pdl=self.launch_with_pdl,
            trigger_completion_at_end=self.trigger_completion_at_end,
            fp32_acc=self.fp32_acc,
            max_token_num=self.max_token_num,
        )


629
630
class AllReduceRMSNormPattern(BasePattern):
    """
    Fuses allreduce followed by rms norm (without residual) into the
    flashinfer fused allreduce op.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self):
        """Return example ``[input, weight]`` tensors for registration."""
        hidden, gamma = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [hidden.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, weight: torch.Tensor):
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)

            return normed, reduced

        def replacement(input: torch.Tensor, weight: torch.Tensor):
            # No residual in this pattern: pass zeros and a fresh buffer
            # for the norm output.
            zero_residual = torch.zeros_like(input)
            norm_buf = torch.empty_like(input)
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buf,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in
            return fused[3], fused[1]

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
682
683
684


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    Fuses allreduce followed by rms norm (with residual) into the
    flashinfer fused allreduce op.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self):
        """Return example ``[residual, input, weight]`` tensors."""
        hidden, res, gamma = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [res, hidden.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass):
        """Register two search/replace pairs on ``pm_pass``: one returning
        (norm output, residual) and one returning only the norm output."""

        def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = self.rmsnorm_matcher(reduced, weight, residual)
            return normed, new_residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ):
            # norm_out=None: the custom op then writes the norm result into
            # allreduce_in and updates residual in place.
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return fused[1], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        output_only_pattern = first_return_only(pattern)
        output_only_replacement = first_return_only(replacement)
        pm.register_replacement(
            output_only_pattern,
            output_only_replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )

749

750
751
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Fuses allreduce + rms norm (without residual) + static fp8 quant
    into the flashinfer fused allreduce op.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        """Register the search/replace pair on ``pm_pass``."""

        def get_inputs():
            hidden, gamma = self.rmsnorm_matcher.inputs()
            _, quant_scale = self.quant_matcher.inputs()

            # input goes through allreduce first, always 16-bit
            return [hidden.to(self.dtype), gamma, quant_scale]

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
            # No residual in this pattern: pass zeros, plus fresh buffers
            # for the norm and quantized outputs.
            zero_residual = torch.zeros_like(input)
            norm_buf = torch.empty_like(input)
            quant_buf = torch.empty_like(input, dtype=self.quant_dtype)
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buf,
                quant_out=quant_buf,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return fused[4], fused[1]

        example_inputs = get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
817
818
819
820
821
822


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Fuses allreduce + rms norm (with residual) + static fp8 quant
    into the fused flashinfer implementation.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        # dtype of the quantized output buffer allocated in the replacement
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def register(self, pm_pass: PatternMatcherPass):
        """Trace the search pattern and its fused replacement into ``pm_pass``."""

        def get_inputs():
            hidden, res, gamma = self.rmsnorm_matcher.inputs()
            _, quant_scale = self.quant_matcher.inputs()

            # input goes through allreduce first, always 16-bit
            return [res, hidden.to(self.dtype), gamma, quant_scale]

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)

            return quantized, residual_out

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)

            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
            )
            extra_kwargs = self.allreduce_params.get_trtllm_fused_allreduce_kwargs()
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                scale_factor=scale,
                **extra_kwargs,
            )

            # quant_out, rms_norm_residual
            return fused[4], fused[2]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )
891
892
893
894


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Fuses allreduce + rms norm (without residual) + static nvfp4 quant
    into the fused flashinfer implementation.

    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def register(self, pm_pass: PatternMatcherPass):
        """Trace the search pattern and its fused replacement into ``pm_pass``."""

        def get_inputs():
            # Example tensors used only for tracing the pattern.
            def _empty(shape, dtype):
                return torch.empty(shape, device=self.device, dtype=dtype)

            input = _empty([1, 16, 16], self.dtype)
            quant_result = _empty((16, 8), torch.uint8)
            input_global_scale = _empty([1, 1], torch.float32)
            weight = _empty([16], self.dtype)
            output_scale = _empty([128, 4], torch.int32)

            return [input, quant_result, weight, input_global_scale, output_scale]

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quant_outs = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, allreduce_output, output_scale
            return quant_outs[1], reduced, quant_outs[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            # The fused op is always given a residual tensor; this pattern
            # has none, so zeros are passed instead.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)

            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            extra_kwargs = self.allreduce_params.get_trtllm_fused_allreduce_kwargs()
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **extra_kwargs,
            )

            # quant_out, allreduce_output, output_scale
            return fused[4], fused[1], fused[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )
977
978
979
980
981
982


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Fuses allreduce + rms norm (with residual) + static nvfp4 quant
    into the fused flashinfer implementation.

    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def register(self, pm_pass: PatternMatcherPass):
        """Trace the search pattern and its fused replacement into ``pm_pass``."""

        def get_inputs():
            # Example tensors used only for tracing the pattern.
            def _empty(shape, dtype):
                return torch.empty(shape, device=self.device, dtype=dtype)

            input = _empty([16, 16], self.dtype)
            residual = _empty([16, 16], self.dtype)
            weight = _empty([16, 16], self.dtype)
            quant_result = _empty((16, 8), torch.uint8)
            input_global_scale = _empty([1, 1], torch.float32)
            output_scale = _empty([128, 4], torch.int32)

            # Order must match the parameter order of pattern/replacement.
            return [
                quant_result,
                residual,
                input,
                output_scale,
                weight,
                input_global_scale,
            ]

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            quant_outs = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, rms_norm_residual, output_scale
            return quant_outs[1], residual_out, quant_outs[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ):
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            extra_kwargs = self.allreduce_params.get_trtllm_fused_allreduce_kwargs()
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **extra_kwargs,
            )

            # quant_out, rms_norm_residual, output_scale
            return fused[4], fused[2], fused[5]

        pm.register_replacement(
            pattern, replacement, get_inputs(), pm.fwd_only, pm_pass
        )
1071
1072


1073
class AllReduceFusionPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass that rewrites allreduce + RMSNorm subgraphs
    (optionally followed by static FP8 / NVFP4 quantization) into the fused
    flashinfer allreduce op.

    The pass is created disabled and only becomes enabled after the
    flashinfer workspace has been created and all patterns are registered.
    """

    def __init__(self, config: VllmConfig):
        """Set up the flashinfer workspace and register the fusion patterns.

        The pass stays disabled (and returns early) when the tensor-parallel
        size is <= 1, ``config.model_config`` is missing, the flashinfer comm
        module is unavailable, or no flashinfer max size is available for the
        current world size.
        """
        super().__init__(config)
        # Only flipped to False at the end of register_patterns().
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        # Bytes per element: 4 for float32 models, 2 otherwise.
        element_size = 4 if use_fp32_lamport else 2
        self.max_token_num = max_size // (self.hidden_dim * element_size)
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            self.max_token_num, config.scheduler_config.max_num_batched_tokens
        )
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB, "
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=self.max_token_num,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        # Publish the workspace through the module-level global.
        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=self.max_token_num,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self):
        """Register every fusion pattern for each supported epsilon value.

        Marks the pass as enabled once all registrations have succeeded.
        """
        for epsilon in [1e-5, 1e-6]:
            AllReduceFusedRMSNormStaticQuantFP8Pattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            # NOTE(review): the NVFP4 patterns reference STATIC_FP4_QUANT_OP,
            # which is only defined when torch.ops._C has scaled_fp4_quant;
            # presumably implied by device capability 100 - confirm.
            if current_platform.has_device_capability(100):
                AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
            AllReduceRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True if the pass is enabled and ``compile_range.end`` does
        not exceed ``self.max_token_num``."""
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return compile_range.end <= self.max_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph``; no-op when disabled.

        The number of replacements is stored in ``self.matched_count``.
        """
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self):
        """Destroy the flashinfer IPC workspace if the pass was enabled."""
        # getattr guards against __init__ failing before `disabled` was set.
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, self.group
            )