collective_fusion.py 44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from importlib.util import find_spec
4
from types import ModuleType
5
6
7
8

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
9
from torch._higher_order_ops.auto_functionalize import auto_functionalized
10
11
12
13
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
14
from vllm.config.utils import Range
15
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
16
from vllm.distributed.parallel_state import (
17
18
19
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
22
23
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
24
from vllm.platforms import current_platform
25
from vllm.utils.torch_utils import direct_register_custom_op
26

27
from .inductor_pass import enable_fake_mode
28
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
29
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
30

31
32
# FP8 dtype used by the scaled-mm example inputs; the concrete variant is
# chosen by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Optional FlashInfer communication module. It is only exposed when the
# package is discoverable, `flashinfer.comm` imports cleanly, and it provides
# `trtllm_allreduce_fusion`; in every other case it stays None.
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer") is not None:
    try:
        import flashinfer.comm as _flashinfer_comm

        _has_fused_allreduce = hasattr(_flashinfer_comm, "trtllm_allreduce_fusion")
        if _has_fused_allreduce:
            flashinfer_comm = _flashinfer_comm
    except ImportError:
        # Best effort: an unusable FlashInfer install just disables the
        # FlashInfer-based fusions.
        pass

logger = init_logger(__name__)

# Only bound when the `_C` extension registers `scaled_fp4_quant`.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default

class BasePattern:
    """State shared by every fusion pattern in this module.

    Holds the dtype/device used to build example inputs for tracing and the
    tensor-parallel group handle plus its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        # Example-input configuration.
        self.dtype = dtype
        self.device = device
        # Tensor-parallel group and its size, captured once at construction.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse ``aten.mm`` followed by vLLM ``reduce_scatter`` on the TP group.

    The matched ``reduce_scatter(mm(mul, mm_weight), dim=0)`` is replaced by
    ``torch.ops.symm_mem.fused_matmul_reduce_scatter``.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[mul, mm_weight]`` tensors used for tracing."""
        example_kwargs = {"device": self.device, "dtype": self.dtype}
        return [
            torch.empty([16, 4], **example_kwargs),
            torch.empty([4, 4], **example_kwargs),
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            product = torch.ops.aten.mm.default(mul, mm_weight)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # NOTE(review): reduce_op is "avg" here, while the matched vLLM
            # reduce_scatter is presumably a sum -- confirm against the
            # torch symm_mem semantics of fused_matmul_reduce_scatter.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherGEMMPattern(BasePattern):
    """Fuse vLLM ``all_gather`` followed by ``aten.mm`` on the TP group.

    The matched ``mm(all_gather(x, dim=0), weight)`` is replaced by
    ``torch.ops.symm_mem.fused_all_gather_matmul``.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[x, weight]`` tensors (both ``[4, 4]``)."""
        return [
            torch.empty([4, 4], device=self.device, dtype=self.dtype)
            for _ in range(2)
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return torch.ops.aten.mm.default(gathered, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # The fused op's result is unpacked into (gathered input, matmul
            # outputs); only the matmul outputs replace the pattern result.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse FP8 ``aten._scaled_mm`` followed by vLLM ``reduce_scatter``.

    The replacement is vLLM's ``patched_fused_scaled_matmul_reduce_scatter``
    custom op, called with the same scales and output dtype as the matched
    ``_scaled_mm``.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, mm_weight, scale_a, scale_b]`` tensors.

        The weight is a transposed view of a contiguous FP8 matrix; the
        scales are fp32 with shapes ``[16, 1]`` and ``[1, 16]``.
        """
        on_device = {"device": self.device}
        lhs = torch.empty([16, 16], dtype=FP8_DTYPE, **on_device)
        rhs = torch.empty([16, 16], dtype=FP8_DTYPE, **on_device)
        rhs = rhs.contiguous().transpose(0, 1)
        row_scales = torch.empty([16, 1], dtype=torch.float32, **on_device)
        col_scales = torch.empty([1, 16], dtype=torch.float32, **on_device)
        return [lhs, rhs, row_scales, col_scales]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            mm_out = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            return torch.ops.vllm.reduce_scatter.default(
                mm_out,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scatter_dim = 0
            # Shape of the full input @ mat2 product; scatter_dim is NOT
            # divided by the world size here (presumably the op does that --
            # TODO confirm).
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            # NOTE(review): "avg" vs. the presumably-summing reduce_scatter in
            # the pattern -- confirm the intended reduction.
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherScaledMMPattern(BasePattern):
    """Fuse vLLM ``all_gather`` followed by FP8 ``aten._scaled_mm``.

    The replacement is ``torch.ops.symm_mem.fused_all_gather_scaled_matmul``
    with a single weight, no bias/result scale and ``self.dtype`` output.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[x, weight, scale_a, scale_b]`` tensors.

        ``scale_a`` has one row per gathered row, i.e. ``tp_size`` times the
        rows of the local ``x`` shard.
        """
        shard = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        weight = weight.contiguous().transpose(0, 1)

        gathered_rows = shard.shape[0] * self.tp_size
        scale_a = torch.empty(
            [gathered_rows, 1], device=self.device, dtype=torch.float32
        )
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [shard, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            return torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Per-weight arguments are passed as single-element lists.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse functionalized ``_C.cutlass_scaled_mm`` + vLLM ``reduce_scatter``.

    The replacement is vLLM's ``patched_fused_scaled_matmul_reduce_scatter``
    custom op; the pattern's preallocated ``cutlass_mm_output`` buffer is not
    used by the replacement.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, mm_weight, scale_a, scale_b, out]``.

        ``out`` is the ``[16, 16]`` buffer that ``cutlass_scaled_mm`` writes
        into, in ``self.dtype``.
        """
        on_device = {"device": self.device}
        lhs = torch.empty([16, 16], dtype=FP8_DTYPE, **on_device)
        rhs = torch.empty([16, 16], dtype=FP8_DTYPE, **on_device)
        rhs = rhs.contiguous().transpose(0, 1)
        row_scales = torch.empty([16, 1], dtype=torch.float32, **on_device)
        col_scales = torch.empty([1, 16], dtype=torch.float32, **on_device)
        out_buffer = torch.empty([16, 16], dtype=self.dtype, **on_device)
        return [lhs, rhs, row_scales, col_scales, out_buffer]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            functionalized_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Element [1] is taken as the written `out` tensor.
            return torch.ops.vllm.reduce_scatter.default(
                functionalized_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            scatter_dim = 0
            # Shape of the full input @ mat2 product; scatter_dim is NOT
            # divided by the world size here (presumably the op does that --
            # TODO confirm).
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            # NOTE(review): "avg" vs. the presumably-summing reduce_scatter in
            # the pattern -- confirm the intended reduction.
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse vLLM ``all_gather`` + functionalized ``_C.cutlass_scaled_mm``.

    The replacement is ``torch.ops.symm_mem.fused_all_gather_scaled_matmul``;
    the pattern's preallocated ``output`` buffer is not used by it.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[x, weight, scale_a, scale_b, output]`` tensors.

        ``scale_a`` and ``output`` are sized for the gathered row count
        (``tp_size`` times the rows of the local ``x`` shard).
        """
        shard = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        weight = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        weight = weight.contiguous().transpose(0, 1)

        gathered_rows = shard.shape[0] * self.tp_size
        scale_a = torch.empty(
            [gathered_rows, 1], device=self.device, dtype=torch.float32
        )
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        out_cols = weight.shape[1]
        output = torch.empty(
            [gathered_rows, out_cols], device=self.device, dtype=self.dtype
        )

        return [shard, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            functionalized_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Element [1] is taken as the written `out` tensor.
            return functionalized_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Per-weight arguments are passed as single-element lists.
            _, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class AsyncTPPass(VllmPatternMatcherPass):
    """Async tensor-parallel pass: fuse GEMMs with adjacent TP collectives.

    Registers patterns that replace ``all_gather`` -> GEMM and
    GEMM -> ``reduce_scatter`` sequences with fused GEMM+collective ops.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # The fused replacements use symmetric memory on the TP device group,
        # so enable it for that group up front.
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )

        dtype, device, patterns = self.model_dtype, self.device, self.patterns
        GEMMReduceScatterPattern(dtype, device).register(patterns)
        AllGatherGEMMPattern(dtype, device).register(patterns)

        # The FP8 fusions are bfloat16-only: `scaled_mm` / `cutlass_scaled_mm`
        # with per-token (row-wise) scaling only support bfloat16 output.
        if dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(dtype, device).register(patterns)
            AllGatherScaledMMPattern(dtype, device).register(patterns)
            CutlassScaledMMReduceScatterPattern(dtype, device).register(patterns)
            AllGatherCutlassScaledMMPattern(dtype, device).register(patterns)

        self.dump_patterns(config, patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether the pass should run for ``compile_range``.

        Mirrors the applicability rule of `SequenceParallelismPass` (see
        `SequenceParallelismPass.is_applicable`), since this pass runs on
        top of it.
        """
        cfg = self.compilation_config
        if not cfg.splitting_ops or cfg.use_inductor_graph_partition:
            return True

        tp_size = get_tensor_model_parallel_world_size()
        if not compile_range.is_single_size():
            return False
        # The single compiled size must split evenly across TP ranks.
        return bool(compile_range.end % tp_size == 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply every registered async-TP pattern to ``graph``."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
# Largest input tensor (in MB) for which the FlashInfer fused allreduce is
# used, keyed first by device capability and then by world size.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
    90: {2: 64, 4: 2, 8: 0.5},  # 64MB / 2MB / 0.5MB
    100: {2: 64, 4: 32, 8: 1},  # 64MB / 32MB / 1MB
}

# Largest input tensor (in MB) for which the FlashInfer *one-shot* fused
# allreduce is used, with the same keys as the table above.
# FlashInfer caps one-shot at 64MB / world size.
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
    90: {2: 32, 4: 2, 8: 0.5},  # 32MB / 2MB / 0.5MB
    100: {2: 32, 4: 4, 8: 1},  # 32MB / 4MB / 1MB
}
if flashinfer_comm is not None:
    # Workspace pointers handed to FlashInfer. Starts as None; it must be set
    # elsewhere before the op runs (the op asserts it is not None).
    _FI_WORKSPACE_TENSOR = None
    MiB = 1024 * 1024

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Run FlashInfer's ``trtllm_allreduce_fusion`` in place.

        When ``norm_out`` is None, the normalized result is written into
        ``allreduce_in`` and the updated residual into ``residual``.
        Otherwise the normalized result goes to ``norm_out`` and the
        residual output is redirected into ``allreduce_in``. That path is
        meant to be used with a zeroed residual, so ``allreduce_in`` ends up
        holding the plain allreduce result.
        One-shot mode is chosen from the per-capability/per-world-size limit
        table; a missing entry means one-shot is always used.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        max_tensor_size = max_token_num * hidden_size * element_size
        assert current_tensor_size <= max_tensor_size, (
            f"Current tensor size {current_tensor_size} is larger than "
            f"max token num {max_token_num} * hidden size {hidden_size} * "
            f"element size {element_size}"
        )

        capability = current_platform.get_device_capability()
        capability_key = None if capability is None else capability.to_int()
        # One-shot limit (MiB) for this device capability and world size.
        per_world_limits = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
            capability_key,  # type: ignore[arg-type, unused-ignore]
            {},
        )
        max_one_shot_size = per_world_limits.get(world_size, None)
        if max_one_shot_size is None:
            use_oneshot = True
        else:
            use_oneshot = current_tensor_size <= max_one_shot_size * MiB

        assert _FI_WORKSPACE_TENSOR is not None, (
            "Flashinfer must be enabled when using flashinfer"
        )
        if norm_out is not None:
            # FlashInfer cannot emit rms_norm output and allreduce_out
            # together, so the allreduce result is returned through
            # residual_out (written into allreduce_in) with a zeroed
            # residual input.
            residual_out = allreduce_in
        else:
            norm_out = allreduce_in
            residual_out = residual

        flashinfer_comm.trtllm_allreduce_fusion(
            allreduce_in=allreduce_in,
            token_num=allreduce_in.shape[0],
            residual_in=residual,
            residual_out=residual_out,
            norm_out=norm_out,
            rms_gamma=rms_gamma,
            rms_eps=rms_eps,
            world_rank=world_rank,
            world_size=world_size,
            hidden_dim=allreduce_in.shape[-1],
            workspace_ptrs=_FI_WORKSPACE_TENSOR,
            launch_with_pdl=launch_with_pdl,
            use_oneshot=use_oneshot,
            trigger_completion_at_end=trigger_completion_at_end,
            fp32_acc=fp32_acc,
            pattern_code=pattern_code,
            allreduce_out=None,
            quant_out=quant_out,
            scale_out=scale_out,
            # vLLM only supports the swizzled scale-factor layout.
            layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
            scale_factor=scale_factor,
        )

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake (meta) implementation: the real op only mutates its inputs."""
        return None

    # Registered as a vLLM custom op; the listed arguments are declared as
    # mutated so functionalization can track the in-place writes.
    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
606
    ) -> None:
607
608
609
610
611
612
613
614
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.max_token_num = max_token_num

615
    def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
616
617
618
619
620
621
622
623
624
625
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
        }


626
627
class AllReduceRMSNormPattern(BasePattern):
    """
    Replace allreduce + RMSNorm (without residual) with the fused FlashInfer
    implementation.
    Targets the allreduce + rmsnorm before attention in the first
    Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, weight]`` tensors for tracing."""
        hidden, gamma = self.rmsnorm_matcher.inputs()
        # The allreduce input is always in the (16-bit) model dtype.
        return [hidden.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            return normed, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # A zero residual plus a separate norm_out buffer leaves the plain
            # allreduce result in `allreduce_in` and the RMSNorm in norm_out.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # [3] is norm_out (RMSNorm result), [1] is allreduce_in.
            return fused[3], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    Replace allreduce + fused-add RMSNorm (with residual) with the fused
    FlashInfer implementation.
    Targets o_proj + rmsnorm after attention and mlp + rmsnorm before
    attention.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[residual, input, weight]`` tensors for tracing."""
        hidden, residual, gamma = self.rmsnorm_matcher.inputs()
        # Only the allreduce input is cast; it is always 16-bit (model dtype).
        return [residual, hidden.to(self.dtype), gamma]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the full pattern and its output-only variant."""

        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = self.rmsnorm_matcher(reduced, weight, residual)
            return normed, new_residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            # norm_out=None: the normalized output overwrites allreduce_in and
            # the updated residual overwrites residual.
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # [1] is allreduce_in, [2] is residual.
            return fused[1], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern returning only the normalized output (useful at the
        # end of the graph where the residual is not used again). The
        # wrappers keep exactly three named parameters so the pattern
        # matcher sees a fixed-arity signature.
        def pattern_output_only(
            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
        ) -> torch.Tensor:
            return pattern(a, b, c)[0]

        def replacement_output_only(
            a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
        ) -> torch.Tensor:
            return replacement(a, b, c)[0]

        pm.register_replacement(
            pattern_output_only,
            replacement_output_only,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Replace allreduce + RMSNorm (without residual) + static FP8 quant with
    the fused FlashInfer implementation.
    Targets allreduce + rmsnorm + quant before attention in the first
    Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        # dtype of the quant_out buffer allocated in the replacement.
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, weight, scale]`` tensors for tracing."""
        hidden, gamma = self.rmsnorm_matcher.inputs()
        _, quant_scale = self.quant_matcher.inputs()
        # The allreduce input is always in the (16-bit) model dtype.
        return [hidden.to(self.dtype), gamma, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            zero_residual = torch.zeros_like(input)
            # norm_out is passed (so the allreduce result lands in
            # allreduce_in) but its contents are not used afterwards.
            unused_norm = torch.empty_like(input)
            quant_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=unused_norm,
                quant_out=quant_buffer,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # [4] is quant_out, [1] is allreduce_in (allreduce output).
            return fused[4], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Fuse allreduce + RMSNorm with residual add + static FP8 quant into a
    single FlashInfer call.

    Targets o_proj + rmsnorm + quant after attention, as well as
    mlp + rmsnorm + quant before attention.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors in ``pattern`` argument order:
        ``[residual, input, weight, scale]``."""
        norm_input, norm_residual, norm_weight = self.rmsnorm_matcher.inputs()
        _, quant_scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        allreduce_input = norm_input.to(self.dtype)
        return [norm_residual, allreduce_input, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = self.rmsnorm_matcher(reduced, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, new_residual

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            quant_buf = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # The normalized output is not requested; only quant_out is used
                norm_out=None,
                quant_out=quant_buf,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return fused[4], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Fuse allreduce + RMSNorm (no residual) + static NVFP4 quant into a
    single FlashInfer call.

    Targets allreduce + rmsnorm + quant before attention in the first
    Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors in ``pattern`` argument order:
        ``[input, quant_result, weight, input_global_scale, output_scale]``."""
        dev = self.device
        inp = torch.empty([1, 16, 16], device=dev, dtype=self.dtype)
        quant_out = torch.empty((16, 8), device=dev, dtype=torch.uint8)
        in_scale = torch.empty([1, 1], device=dev, dtype=torch.float32)
        gamma = torch.empty([16], device=dev, dtype=self.dtype)
        out_scale = torch.empty([128, 4], device=dev, dtype=torch.int32)
        return [inp, quant_out, gamma, in_scale, out_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )
            # quant_out, allreduce_output, output_scale
            return quantized[1], reduced, quantized[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # The kernel is invoked in its residual mode; a zero residual is
            # supplied because the matched subgraph has none.
            zero_residual = torch.zeros_like(input)
            norm_buf = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                # norm_out is passed in but not read afterwards
                norm_out=norm_buf,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, allreduce_output, output_scale
            return fused[4], fused[1], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors used to trace ``pattern``/``replacement``.

        The list order matches the argument order of ``pattern`` and
        ``replacement``: ``[quant_result, residual, input, output_scale,
        weight, input_global_scale]``.
        """
        input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        # The RMSNorm weight is a 1-D vector over the hidden dim
        # (input.shape[-1]), matching the no-residual NVFP4 pattern.
        weight = torch.empty([16], device=self.device, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
        input_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

        return [
            quant_result,
            residual,
            input,
            output_scale,
            weight,
            input_global_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            # Bind the updated residual to a new name instead of shadowing the
            # ``residual`` parameter.
            rms, new_residual = self.rmsnorm_matcher(
                allreduce_output, weight, residual
            )
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )

            # quant_out, rms_norm_residual, output_scale
            return quant_out_tuple[1], new_residual, quant_out_tuple[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # The normalized output is not requested; only quant_out is used
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


1086
class AllReduceFusionPass(VllmPatternMatcherPass):
    """
    Replace tensor-parallel all-reduce + RMSNorm (optionally followed by a
    static FP8 or NVFP4 quant) with FlashInfer's fused allreduce kernels.

    The pass stays disabled (``self.disabled``) when the TP size is <= 1, the
    model config is missing, FlashInfer's comm module is unavailable, or no
    FlashInfer max size is available for the current world size.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        # Bytes per element: 4 when fp32 Lamport buffers are used, else 2.
        element_size = 4 if use_fp32_lamport else 2
        self.max_token_num = max_size // (self.hidden_dim * element_size)
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            self.max_token_num, config.scheduler_config.max_num_batched_tokens
        )
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB, "
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=self.max_token_num,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=self.max_token_num,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self) -> None:
        """Register every fusion pattern for each supported RMSNorm epsilon.

        Registration order is preserved: the quantized FP8 variants first, then
        the NVFP4 variants (only when
        ``current_platform.has_device_capability(100)``), then the plain
        allreduce + RMSNorm patterns. Enables the pass once all are registered.
        """
        for epsilon in [1e-5, 1e-6]:
            pattern_classes = [
                AllReduceFusedRMSNormStaticQuantFP8Pattern,
                AllReduceFusedAddRMSNormStaticQuantFP8Pattern,
            ]
            if current_platform.has_device_capability(100):
                pattern_classes += [
                    AllReduceFusedRMSNormStaticQuantNVFP4Pattern,
                    AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern,
                ]
            pattern_classes += [
                AllReduceRMSNormPattern,
                AllReduceFusedAddRMSNormPattern,
            ]
            for pattern_cls in pattern_classes:
                pattern_cls(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when enabled and ``compile_range.end`` fits within the
        FlashInfer workspace token budget (``self.max_token_num``)."""
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered replacements to ``graph`` (no-op if disabled)."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self) -> None:
        """Destroy the FlashInfer IPC workspace if the pass was enabled."""
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, self.group
            )