"tests/models/multimodal/generation/test_common.py" did not exist on "5c9121203cc34e781f1f249b69cb789244e861f0"
collective_fusion.py 43.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from importlib.util import find_spec
4
from types import ModuleType
5
6
7
8

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
9
from torch._higher_order_ops.auto_functionalize import auto_functionalized
10
11
12
13
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
14
from vllm.config.utils import Range
15
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
16
from vllm.distributed.parallel_state import (
17
18
19
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
20
from vllm.logger import init_logger
21
22
23
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
24
from vllm.platforms import current_platform
25
from vllm.utils.torch_utils import direct_register_custom_op
26

27
from .inductor_pass import enable_fake_mode
28
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
29
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
30

31
32
FP8_DTYPE = current_platform.fp8_dtype()

# Resolve the optional FlashInfer communication module. It is kept only when
# the `flashinfer` package is installed, `flashinfer.comm` imports cleanly and
# it exposes `trtllm_allreduce_fusion`; in every other case `flashinfer_comm`
# is None and the code guarded by `if flashinfer_comm is not None` is skipped.
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as flashinfer_comm

        flashinfer_comm: ModuleType | None = (  # type: ignore[no-redef]
            flashinfer_comm
            if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
            else None
        )
    except ImportError:
        flashinfer_comm = None  # type: ignore[assignment]
else:
    flashinfer_comm = None  # type: ignore[assignment]


logger = init_logger(__name__)

# Only defined when the compiled `_C` extension registers `scaled_fp4_quant`.
if hasattr(torch.ops._C, "scaled_fp4_quant"):
    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
51

52
53

class BasePattern:
    """Common state shared by the fusion patterns in this module.

    Stores the model dtype and device (used by subclasses to build the example
    tensors that patterns are traced with) together with the tensor-parallel
    group and its world size.
    """

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        self.dtype = dtype
        self.device = device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):
    """Fuse `aten.mm` followed by a TP `reduce_scatter` into one symm-mem op.

    Matches `mm(mul, mm_weight)` + `vllm.reduce_scatter(dim=0)` and replaces it
    with `symm_mem.fused_matmul_reduce_scatter` on the TP device group.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used to trace the pattern and the replacement.
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # NOTE(review): "avg" is passed as the reduce op although the
            # matched pattern is a plain reduce_scatter; confirm this matches
            # the reduction semantics expected by callers.
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
92
93
94


class AllGatherGEMMPattern(BasePattern):
    """Fuse a TP `all_gather` followed by `aten.mm` into one symm-mem op.

    Matches `vllm.all_gather(x, dim=0)` + `mm(gathered, weight)` and replaces it
    with `symm_mem.fused_all_gather_matmul`, returning only the matmul outputs.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        # Example tensors used to trace the pattern and the replacement.
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # The fused op also returns the gathered input; only the matmul
            # results replace the pattern's output.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
127
128


129
class ScaledMMReduceScatterPattern(BasePattern):
    """Fuse FP8 `aten._scaled_mm` + TP `reduce_scatter` into one op.

    Matches `_scaled_mm(...)` followed by `vllm.reduce_scatter(dim=0)` and
    replaces it with `vllm.patched_fused_scaled_matmul_reduce_scatter`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # Transposed view of a contiguous tensor, i.e. a column-major weight.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Row-wise scales for `input`, column-wise scales for `mm_weight`.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
        return [input, mm_weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            scaled_mm = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                scaled_mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of the un-scattered product `input @ mat2`.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): "avg" is passed as the reduce op although the
            # matched pattern is a plain reduce_scatter; confirm intent.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
195
196
197


class AllGatherScaledMMPattern(BasePattern):
    """Fuse a TP `all_gather` followed by FP8 `aten._scaled_mm` into one op.

    Matches `vllm.all_gather(x, dim=0)` + `_scaled_mm(gathered, weight, ...)`
    and replaces it with `symm_mem.fused_all_gather_scaled_matmul`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Transposed view of a contiguous tensor, i.e. a column-major weight.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # `scale_a` applies to the gathered input, which has tp_size times as
        # many rows as the local shard `x`.
        gathered_rows = x.shape[0] * self.tp_size

        scale_a = torch.empty(
            [gathered_rows, 1], device=self.device, dtype=torch.float32
        )
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            return torch.ops.aten._scaled_mm.default(
                all_gather,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # The fused op also returns the gathered input; only the matmul
            # results replace the pattern's output.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
257
258
259


class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Fuse `_C.cutlass_scaled_mm` + TP `reduce_scatter` into one op.

    Matches the auto-functionalized `cutlass_scaled_mm` followed by
    `vllm.reduce_scatter(dim=0)` and replaces it with
    `vllm.patched_fused_scaled_matmul_reduce_scatter`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
        # Transposed view of a contiguous tensor, i.e. a column-major weight.
        mm_weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Row-wise scales for `input`, column-wise scales for `mm_weight`.
        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        # Preallocated `out` buffer that cutlass_scaled_mm writes into.
        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)

        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            # Index 1 is presumably the functionalized `out` tensor.
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                cutlass_scaled_mm[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return reduce_scatter

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of the un-scattered product `input @ mat2`.
            output_shape = [*input.shape[:-1], mat2.shape[1]]
            scatter_dim = 0
            # NOTE(review): "avg" is passed as the reduce op although the
            # matched pattern is a plain reduce_scatter; confirm intent.
            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

            return gemm_rs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
330
331
332


class AllGatherCutlassScaledMMPattern(BasePattern):
    """Fuse a TP `all_gather` followed by `_C.cutlass_scaled_mm` into one op.

    Matches `vllm.all_gather(x, dim=0)` + auto-functionalized
    `cutlass_scaled_mm` and replaces it with
    `symm_mem.fused_all_gather_scaled_matmul`.
    """

    def get_inputs(self) -> list[torch.Tensor]:
        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
        # Transposed view of a contiguous tensor, i.e. a column-major weight.
        weight = (
            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )

        # The gathered input has tp_size times as many rows as the shard `x`;
        # `scale_a` and the `out` buffer are sized for the gathered input.
        gathered_rows = x.shape[0] * self.tp_size

        scale_a = torch.empty(
            [gathered_rows, 1], device=self.device, dtype=torch.float32
        )
        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)

        out_cols = weight.shape[1]
        output = torch.empty(
            [gathered_rows, out_cols], device=self.device, dtype=self.dtype
        )

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            all_gather = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=all_gather,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Index 1 is presumably the functionalized `out` tensor.
            return cutlass_scaled_mm[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # `output` is not needed: the fused op allocates its own result.
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
398
399


400
class AsyncTPPass(VllmPatternMatcherPass):
    """Rewrite TP collective + GEMM pairs into fused symmetric-memory ops.

    Registers the all-gather/GEMM and GEMM/reduce-scatter patterns defined in
    this module (the FP8 scaled-mm variants only for bfloat16 models) and
    applies them to the FX graph.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )
        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
                self.patterns
            )
            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
                self.patterns
            )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether this pass should run for ``compile_range``.

        Always applicable when there are no splitting ops or when inductor
        graph partitioning is used. Otherwise it applies only to single-size
        ranges whose size is divisible by the TP world size.
        """
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        tp_size = get_tensor_model_parallel_world_size()
        return compile_range.is_single_size() and compile_range.end % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
450
451


452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
# Size limits (in MB) for the FlashInfer fused allreduce, keyed first by device
# capability as an int (e.g. 90, 100), then by TP world size.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
    90: {2: 64, 4: 2, 8: 0.5},  # MB per world size
    100: {2: 64, 4: 32, 8: 1},  # MB per world size
}

# Same layout, but limits for the FlashInfer one-shot fused allreduce.
# FlashInfer restricts one-shot to at most 64MB / world size.
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
    90: {2: 32, 4: 2, 8: 0.5},  # MB per world size
    100: {2: 32, 4: 4, 8: 1},  # MB per world size
}


484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
if flashinfer_comm is not None:
    # Passed to FlashInfer as `workspace_ptrs`; it must be set to a non-None
    # value before the fused op is called (asserted below).
    _FI_WORKSPACE_TENSOR = None
    MiB = 1024 * 1024

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Run FlashInfer's `trtllm_allreduce_fusion` in place.

        Results are written into the argument tensors (see `mutates_args` in
        the registration below); nothing is returned.

        If ``norm_out`` is None, the normalized output is written back into
        ``allreduce_in`` and the residual output into ``residual``. Otherwise
        the normalized output goes to ``norm_out`` and the residual output is
        written into ``allreduce_in``; with a zeroed ``residual`` that yields
        the allreduce result, since FlashInfer cannot emit both the norm and
        the allreduce output at once.

        One-shot mode is used when no limit is configured for the current
        device capability / world size, or when the input fits the limit in
        ``_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB``.

        Raises:
            AssertionError: if the input is larger than ``max_token_num`` rows,
                or if ``_FI_WORKSPACE_TENSOR`` has not been set.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        max_tensor_size = max_token_num * hidden_size * element_size
        assert current_tensor_size <= max_tensor_size, (
            f"Current tensor size {current_tensor_size} is larger than "
            f"max token num {max_token_num} * hidden size {hidden_size} * "
            f"element size {element_size}"
        )

        curr_device = current_platform.get_device_capability()
        device_capability = curr_device.to_int() if curr_device is not None else None
        # Get one shot input size limit for the current world size
        # for the current device capability
        max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
            device_capability,  # type: ignore[arg-type]
            {},
        ).get(world_size, None)
        # Use one shot if no max size is specified
        use_oneshot = (
            max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
        )

        assert _FI_WORKSPACE_TENSOR is not None, (
            "Flashinfer must be enabled when using flashinfer"
        )
        if norm_out is None:
            norm_out = allreduce_in
            residual_out = residual
        else:
            # return residual_out as allreduce_out with zeroed residual_in
            # as flashinfer does not support rms_norm
            # and allreduce_out together
            residual_out = allreduce_in
        flashinfer_comm.trtllm_allreduce_fusion(
            allreduce_in=allreduce_in,
            token_num=allreduce_in.shape[0],
            residual_in=residual,
            residual_out=residual_out,
            norm_out=norm_out,
            rms_gamma=rms_gamma,
            rms_eps=rms_eps,
            world_rank=world_rank,
            world_size=world_size,
            hidden_dim=allreduce_in.shape[-1],
            workspace_ptrs=_FI_WORKSPACE_TENSOR,
            launch_with_pdl=launch_with_pdl,
            # Chosen above from the per-capability one-shot size limits.
            use_oneshot=use_oneshot,
            trigger_completion_at_end=trigger_completion_at_end,
            fp32_acc=fp32_acc,
            pattern_code=pattern_code,
            allreduce_out=None,
            quant_out=quant_out,
            scale_out=scale_out,
            # in vllm we only support swizzled layout
            layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
            scale_factor=scale_factor,
        )

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_rank: int,
        world_size: int,
        launch_with_pdl: bool,
        trigger_completion_at_end: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake implementation registered as ``fake_impl``.

        The real op returns None and only mutates its arguments, so there is
        nothing to compute here.
        """
        pass

    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )
599
600
601
602
603
604
605
606
607
608
609


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
610
    ) -> None:
611
612
613
614
615
616
617
618
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.max_token_num = max_token_num

619
    def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
620
621
622
623
624
625
626
627
628
629
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
        }


630
631
class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        input, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(allreduce_output, weight)

            return rms, allreduce_output

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The op is run with the residual+RMSNorm pattern code; a zero
            # residual makes it equivalent to a plain RMSNorm. Because
            # `norm_out` is provided, the residual output (here, the allreduce
            # result) is written into `allreduce_in`.
            residual = torch.zeros_like(input)
            rms_result = torch.empty_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Outputs follow the op's mutates_args order after index 0:
            # rms_result (norm_out), allreduce_in
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
687
688
689


class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        input, residual, weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [residual, input.to(self.dtype), weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
            return rms, residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # With `norm_out=None` the op writes the normalized output back into
            # `allreduce_in` and the updated residual into `residual`.
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return allreduce[1], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        pm.register_replacement(
            first_return_only(pattern),  # type: ignore[no-untyped-call]
            first_return_only(replacement),  # type: ignore[no-untyped-call]
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )

756

757
758
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        input, weight = self.rmsnorm_matcher.inputs()
        _, scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # A zero residual makes the residual+RMSNorm fusion act as a plain
            # RMSNorm; since `norm_out` is provided, the residual output (here,
            # the allreduce result) is written into `allreduce_in`.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # We don't use norm_out afterwards
                norm_out=result_rms,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output
            return allreduce[4], allreduce[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
826
827
828
829
830
831


class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Create the pattern.

        Args:
            epsilon: RMSNorm epsilon the pattern is specialized for.
            dtype: model dtype used for the traced allreduce input.
            device: device used for the example tensors.
            allreduce_params: supplies the extra kwargs forwarded to the
                fused flashinfer op.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing, ordered like ``pattern``'s params."""
        input, residual, weight = self.rmsnorm_matcher.inputs()
        _, scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [residual, input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the allreduce + fused-add RMSNorm + static FP8 rewrite.

        Args:
            pm_pass: pattern-matcher pass the replacement is registered into.
        """

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, res

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # No dedicated norm_out buffer is requested: only the
                # quantized output and the updated residual are returned.
                norm_out=None,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            return allreduce[4], allreduce[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Create the pattern.

        Args:
            epsilon: RMSNorm epsilon the pattern is specialized for.
            dtype: model dtype used for the example activations.
            device: device used for the example tensors.
            allreduce_params: supplies the extra kwargs forwarded to the
                fused flashinfer op.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing, ordered like ``pattern``'s params."""
        input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
        input_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        weight = torch.empty([16], device=self.device, dtype=self.dtype)
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

        return [input, quant_result, weight, input_global_scale, output_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the allreduce + RMSNorm + static NVFP4 quant rewrite.

        Args:
            pm_pass: pattern-matcher pass the replacement is registered into.
        """

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            all_reduce = tensor_model_parallel_all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # The pattern has no residual, but the fused kernel pattern used
            # below is the residual variant, so feed it an all-zero residual.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # Scratch buffer only: norm_out is not returned below.
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # quant_out, allreduce_output, output_scale
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Create the pattern.

        Args:
            epsilon: RMSNorm epsilon the pattern is specialized for.
            dtype: model dtype used for the example activations.
            device: device used for the example tensors.
            allreduce_params: supplies the extra kwargs forwarded to the
                fused flashinfer op.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing, ordered like ``pattern``'s params."""
        input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        # RMSNorm gamma is 1-D over the hidden dim (same as the non-residual
        # NVFP4 pattern); it was previously traced with a 2-D [16, 16] tensor.
        weight = torch.empty([16], device=self.device, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
        input_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)

        return [
            quant_result,
            residual,
            input,
            output_scale,
            weight,
            input_global_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the allreduce + fused-add RMSNorm + static NVFP4 rewrite.

        Args:
            pm_pass: pattern-matcher pass the replacement is registered into.
        """

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rms,
                output_scale=output_scale,
                input_scale=input_global_scale,
            )

            # quant_out, rms_norm_residual, output_scale
            return quant_out_tuple[1], res, quant_out_tuple[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                # No dedicated norm_out buffer is requested: only the
                # quantized output, residual and scales are returned.
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class AllReduceFusionPass(VllmPatternMatcherPass):
    """Fuse tensor-parallel allreduce with RMSNorm (and optional quant).

    Registers rewrites that replace ``tensor_model_parallel_all_reduce``
    followed by RMSNorm (with or without residual add, optionally followed
    by static FP8 or NVFP4 quantization) with FlashInfer's TRT-LLM fused
    allreduce op.

    The pass stays disabled (``self.disabled`` is True) when any of these
    hold:

    - the TP world size is <= 1;
    - ``model_config`` is missing;
    - FlashInfer's comm module is unavailable;
    - no FlashInfer max size is available for the current world size.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        # fp32 models use fp32 Lamport buffers (4-byte elements, see
        # element_size below); other dtypes use 2-byte elements.
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        element_size = 4 if use_fp32_lamport else 2
        self.max_token_num = max_size // (self.hidden_dim * element_size)
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            self.max_token_num, config.scheduler_config.max_num_batched_tokens
        )
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB, "
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(  # type: ignore[misc]
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=self.max_token_num,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        # The workspace is published through a module-level global so the
        # fused op can reach it at runtime.
        global _FI_WORKSPACE_TENSOR
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=self.max_token_num,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self) -> None:
        """Register all fusion patterns for each supported RMSNorm epsilon.

        NVFP4 patterns are only registered when
        ``current_platform.has_device_capability(100)`` holds. On success
        this enables the pass (``self.disabled = False``).
        """
        for epsilon in [1e-5, 1e-6]:
            AllReduceFusedRMSNormStaticQuantFP8Pattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            if current_platform.has_device_capability(100):
                AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
            AllReduceRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True if the pass is enabled and the range fits the workspace.

        The range fits when ``compile_range.end`` does not exceed
        ``self.max_token_num``, the token budget the IPC workspace was sized
        for.
        """
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return compile_range.end <= self.max_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered rewrites to ``graph`` in place."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self) -> None:
        """Release the FlashInfer IPC workspace if the pass was enabled."""
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, self.group
            )