# sequence_parallelism.py (18.3 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
11
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
12
from vllm.logger import init_logger
13
from vllm.platforms import current_platform
14

15
from .inductor_pass import enable_fake_mode
16
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
17
18
19
20

# Module-level logger; used by SequenceParallelismPass.__call__ to report how
# many patterns were replaced.
logger = init_logger(__name__)


class _RMSNormAndQuantOpHelper:
    """Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
23

24
25
26
27
28
    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
29
        quant_op: torch._ops.OpOverload | None = None,
30
31
        **kwargs,
    ):
32
33
34
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
35
36
37
38
39
40
41
42
        self.quant_op = quant_op

    def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.rms_norm.default,
            result=result_buffer,
            input=input_tensor,
            weight=weight_tensor,
43
44
            epsilon=self.epsilon,
        )
45

46
47
48
    def _functional_fused_add_rmsnorm(
        self, input_tensor, residual_tensor, weight_tensor
    ):
49
50
51
52
53
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.fused_add_rms_norm.default,
            input=input_tensor,
            residual=residual_tensor,
            weight=weight_tensor,
54
55
56
57
58
59
60
61
62
63
64
            epsilon=self.epsilon,
        )

    def _functional_rmsnorm_then_quant(
        self,
        rmsnorm_result_buffer,
        quant_result_buffer,
        input_tensor,
        weight_tensor,
        scale_tensor,
    ):
65
66
67
68
        if self.quant_op is None:
            raise RuntimeError(
                "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
            )
69
70
71
        rmsnorm_out_tuple = self._functional_rmsnorm(
            rmsnorm_result_buffer, input_tensor, weight_tensor
        )
72
73
74
75
        quant_out_tuple = torch.ops.higher_order.auto_functionalized(
            self.quant_op,
            result=quant_result_buffer,
            input=rmsnorm_out_tuple[1],
76
77
            scale=scale_tensor,
        )
78
79
        return quant_out_tuple

80
81
82
83
84
85
86
87
    def _functional_fused_add_rmsnorm_then_quant(
        self,
        quant_result_buffer,
        input_tensor,
        residual_tensor,
        weight_tensor,
        scale_tensor,
    ):
88
89
90
91
92
        if self.quant_op is None:
            raise RuntimeError(
                "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
            )
        fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
93
94
            input_tensor, residual_tensor, weight_tensor
        )
95
96
97
98
        quant_out_tuple = torch.ops.higher_order.auto_functionalized(
            self.quant_op,
            result=quant_result_buffer,
            input=fused_add_rmsnorm_out_tuple[1],
99
100
            scale=scale_tensor,
        )
101
102
103
104
105
106
        return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]


class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
    """RMSNorm/quant helper extended with tensor-parallel collectives.

    Search patterns use ``_all_reduce``. Replacements use ``_reduce_scatter``
    and ``_all_gather``, both along dim 0 over the TP group captured at
    construction.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        quant_op: torch._ops.OpOverload | None = None,
        **kwargs,
    ):
        super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
        # Snapshot the TP group and world size once; the collective helpers
        # below read these attributes on every call.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` across the tensor-parallel group."""
        return tensor_model_parallel_all_reduce(x)

    def _dim0_collective_kwargs(self) -> dict:
        """Keyword arguments shared by the dim-0 reduce-scatter / all-gather."""
        return {
            "dim": 0,
            "world_size": self.tp_size,
            "group_name": self.tp_group.unique_name,
        }

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 across the TP group."""
        return torch.ops.vllm.reduce_scatter.default(
            x, **self._dim0_collective_kwargs()
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 across the TP group."""
        return torch.ops.vllm.all_gather.default(
            x, **self._dim0_collective_kwargs()
        )


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """First-layer rewrite: all-reduce feeding a plain (no residual) RMSNorm.

    Search:      all_reduce(input) -> rms_norm        => (normed, reduced)
    Replacement: reduce_scatter(input) -> rms_norm -> all_gather(normed)
                                                       => (gathered, scattered)
    """

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""

        def _empty(shape):
            return torch.empty(shape, device=self.device, dtype=self.dtype)

        return [_empty([1, 8, 4]), _empty([1, 8, 4]), _empty([4])]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            normed = self._functional_rmsnorm(permute, reduced, arg3_1)[1]
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            scattered = self._reduce_scatter(input)
            # A new output buffer shaped like the scattered shard. `permute` is
            # presumably the unsharded-size buffer from the pattern — confirm.
            out_buf = torch.empty_like(scattered)
            normed = self._functional_rmsnorm(out_buf, scattered, arg3_1)[1]
            return self._all_gather(normed), scattered

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Middle-layer rewrite: all-reduce feeding fused residual-add + RMSNorm.

    Both the normalized output and the updated residual are returned. In the
    replacement only the normalized output is all-gathered; the residual output
    is returned without an all-gather.
    """

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""
        # NOTE(review): the weight is 2-D here, while FirstAllReduceRMSNormPattern
        # uses a 1-D weight. It is presumably only a tracing placeholder — confirm.
        mm_1, residual, rms_norm_weights = (
            torch.empty([4, 4], device=self.device, dtype=self.dtype)
            for _ in range(3)
        )
        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self._functional_fused_add_rmsnorm(
                self._all_reduce(mm_1), residual, rms_norm_weights
            )
            return fused[1], fused[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self._functional_fused_add_rmsnorm(
                self._reduce_scatter(mm_1), residual, rms_norm_weights
            )
            return self._all_gather(fused[1]), fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Last-layer rewrite: all-reduce feeding fused residual-add + RMSNorm.

    Only the normalized output is returned; the updated residual is not.

    Search:      all_reduce(mm_1) -> fused_add_rms_norm        => normed
    Replacement: reduce_scatter(mm_1) -> fused_add_rms_norm
                 -> all_gather(normed)                          => gathered
    """

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # NOTE(review): the 2-D weight is presumably only a tracing placeholder.
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        # Both functions return a single tensor, so they are annotated
        # ``-> torch.Tensor``. The previous ``tuple[...]`` annotation was wrong.
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights
            )
            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights
            )
            normalized = self._all_gather(rmsnorm[1])
            return normalized

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


# FP8 dtype reported by the current platform. The static-FP8 patterns in this
# module use it as the dtype of their quantized-output example tensors.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """First-layer rewrite: all-reduce -> RMSNorm -> static FP8 quantization.

    Search:      all_reduce(input) -> rms_norm -> quant   => (quantized, reduced)
    Replacement: reduce_scatter(input) -> rms_norm -> quant
                 -> all_gather(quantized)                 => (gathered, scattered)
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """``op`` is the quantization op applied after RMSNorm."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""
        act_shape = [1, 8, 4]
        activation = torch.zeros(act_shape, device=self.device, dtype=self.dtype)
        rms_buf = torch.empty(act_shape, device=self.device, dtype=self.dtype)
        fp8_buf = torch.empty(act_shape, device=self.device, dtype=FP8_DTYPE)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [activation, rms_buf, fp8_buf, weight, scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            quantized = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, reduced, weight, scale
            )
            return quantized[1], reduced

        def replacement(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            scattered = self._reduce_scatter(input)
            # New buffers shaped like the scattered shard. Only the dtypes of
            # the matched rmsnorm_result / quant_result buffers are reused.
            local_rms_buf = torch.empty_like(scattered, dtype=rmsnorm_result.dtype)
            local_quant_buf = torch.empty_like(
                local_rms_buf, dtype=quant_result.dtype
            )
            quantized = self._functional_rmsnorm_then_quant(
                local_rms_buf, local_quant_buf, scattered, weight, scale
            )
            return self._all_gather(quantized[1]), scattered

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Middle-layer rewrite: all-reduce -> fused residual-add + RMSNorm ->
    static FP8 quantization.

    Returns ``(quantized, residual)``. In the replacement only the quantized
    output is all-gathered; the residual output is returned without an
    all-gather.
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """``op`` is the quantization op applied after RMSNorm."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""
        shape = [4, 4]
        mm_1, residual, rms_norm_weights = (
            torch.empty(shape, device=self.device, dtype=self.dtype) for _ in range(3)
        )
        fp8_out = torch.empty(shape, device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
        return [fp8_out, residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            quant_tuple, new_residual = self._functional_fused_add_rmsnorm_then_quant(
                result, self._all_reduce(mm_1), residual, rms_norm_weights, scale
            )
            return quant_tuple[1], new_residual

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            # New quant-output buffer shaped like the shard, with the FP8
            # dtype of the matched `result`.
            fp8_buf = torch.empty_like(scattered, dtype=result.dtype)
            quant_tuple, new_residual = self._functional_fused_add_rmsnorm_then_quant(
                fp8_buf, scattered, residual, rms_norm_weights, scale
            )
            return self._all_gather(quant_tuple[1]), new_residual

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Last-layer rewrite: all-reduce -> fused residual-add + RMSNorm ->
    static FP8 quantization, where only the quantized output is returned (the
    updated residual is discarded).

    Search:      all_reduce(mm_1) -> fused_add_rms_norm -> quant
    Replacement: reduce_scatter(mm_1) -> fused_add_rms_norm -> quant
                 -> all_gather(quantized)
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """``op`` is the quantization op applied after RMSNorm."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors for tracing, in ``pattern``/``replacement`` order."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register this rewrite on ``pm_pass``."""

        # Both functions return a single tensor (the quantized output), so they
        # are annotated ``-> torch.Tensor``. The previous ``tuple[...]``
        # annotation was wrong.
        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                result, all_reduce, residual, rms_norm_weights, scale
            )
            return static_fp8[1]

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            # New quant-output buffer shaped like the shard, with the FP8
            # dtype of the matched `result`.
            quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
            )
            normalized = self._all_gather(static_fp8[1])
            return normalized

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Build and register every sequence-parallel rewrite pattern.

        Patterns are registered once for each supported epsilon (1e-5 and 1e-6),
        using the model dtype and device from the base pass.
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # The quant op does not depend on epsilon, so look it up once rather
        # than on every loop iteration.
        fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default

        for epsilon in [1e-5, 1e-6]:
            # NOTE(review): registration order is kept as-is. The FP8 patterns
            # are registered before the plain RMSNorm ones, presumably so the
            # longer match is preferred — confirm before reordering.

            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)
            LastAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            LastAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable(self, shape: int | None) -> bool:
        """Return whether this pass may run for a compile of ``shape``.

        Args:
            shape: The concrete token count being compiled, or ``None`` when
                the sequence dimension is symbolic.
        """
        # When sequence parallelism is enabled, the residual tensor from RMSNorm
        # needs to be split along the sequence dimension. However, this dimension
        # is symbolic during piecewise compilation, and splitting symbolic shapes
        # is not supported.
        #
        # This pass is therefore only applied when the sequence dimension is
        # concrete:
        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
        #   For this case we always pad num_tokens to be a multiple of
        #   tensor_parallel_size, so there's no need to check shape % tp_size == 0.
        # 2. For specific shape provided during compilation (e.g., from
        #    `compile_sizes`), which must be divisible by the tensor-parallel
        #    size.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True

        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph`` and log the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)