# File-viewer residue (not part of the module): sequence_parallelism.py 17.4 KB / Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Optional
4
5
6
7
8
9
10
11

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
12
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
13
from vllm.logger import init_logger
14
from vllm.platforms import current_platform
15

16
from .inductor_pass import enable_fake_mode
17
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
18
19
20
21

# Module-level logger, named after this module.
logger = init_logger(__name__)


22
23
class _RMSNormAndQuantOpHelper:
    """Base helper for RMSNorm and RMSNorm + Quantization functionalization.

    Every ``_functional_*`` method wraps a custom op in
    ``torch.ops.higher_order.auto_functionalized`` and hands back that call's
    result, which callers index positionally.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        quant_op: Optional[torch._ops.OpOverload] = None,
        **kwargs,
    ):
        """Store the parameters shared by the functionalized op wrappers.

        Args:
            epsilon: Epsilon forwarded to the RMSNorm ops.
            dtype: Stored as ``self.dtype``; not read by this class itself.
            device: Stored as ``self.device``; not read by this class itself.
            quant_op: Op applied after the norm by the ``*_then_quant``
                methods. When ``None`` those methods raise ``RuntimeError``.
            **kwargs: Accepted and ignored.
        """
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.quant_op = quant_op

    def _require_quant_op(self) -> torch._ops.OpOverload:
        """Return ``self.quant_op``, raising ``RuntimeError`` if it is unset."""
        if self.quant_op is None:
            raise RuntimeError(
                "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
            )
        return self.quant_op

    def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
        """Functionalized ``torch.ops._C.rms_norm`` writing to ``result_buffer``.

        Returns the raw ``auto_functionalized`` result; callers in this file
        read element ``[1]`` as the normalized output.
        """
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.rms_norm.default,
            result=result_buffer,
            input=input_tensor,
            weight=weight_tensor,
            epsilon=self.epsilon,
        )

    def _functional_fused_add_rmsnorm(
        self, input_tensor, residual_tensor, weight_tensor
    ):
        """Functionalized ``torch.ops._C.fused_add_rms_norm``.

        Returns the raw ``auto_functionalized`` result; callers in this file
        read element ``[1]`` as the normalized output and element ``[2]`` as
        the updated residual.
        """
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.fused_add_rms_norm.default,
            input=input_tensor,
            residual=residual_tensor,
            weight=weight_tensor,
            epsilon=self.epsilon,
        )

    def _functional_rmsnorm_then_quant(
        self,
        rmsnorm_result_buffer,
        quant_result_buffer,
        input_tensor,
        weight_tensor,
        scale_tensor,
    ):
        """Run RMSNorm, then feed its output to ``self.quant_op``.

        Returns:
            The ``auto_functionalized`` result of the quantization op.

        Raises:
            RuntimeError: If no ``quant_op`` was given at construction.
        """
        # Validate before emitting any op so a misconfigured helper fails early.
        quant_op = self._require_quant_op()
        rmsnorm_out_tuple = self._functional_rmsnorm(
            rmsnorm_result_buffer, input_tensor, weight_tensor
        )
        return torch.ops.higher_order.auto_functionalized(
            quant_op,
            result=quant_result_buffer,
            input=rmsnorm_out_tuple[1],
            scale=scale_tensor,
        )

    def _functional_fused_add_rmsnorm_then_quant(
        self,
        quant_result_buffer,
        input_tensor,
        residual_tensor,
        weight_tensor,
        scale_tensor,
    ):
        """Run fused add + RMSNorm, then feed its output to ``self.quant_op``.

        Returns:
            A pair ``(quant_out_tuple, residual_out)``: the
            ``auto_functionalized`` result of the quantization op and element
            ``[2]`` of the fused add + RMSNorm result.

        Raises:
            RuntimeError: If no ``quant_op`` was given at construction.
        """
        # Validate before emitting any op so a misconfigured helper fails early.
        quant_op = self._require_quant_op()
        fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
            input_tensor, residual_tensor, weight_tensor
        )
        quant_out_tuple = torch.ops.higher_order.auto_functionalized(
            quant_op,
            result=quant_result_buffer,
            input=fused_add_rmsnorm_out_tuple[1],
            scale=scale_tensor,
        )
        return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]


class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
    """Helper for sequence parallelism patterns.

    Adds the tensor-parallel collectives (all-reduce, reduce-scatter,
    all-gather) used by the pattern and replacement functions on top of the
    RMSNorm / quantization wrappers of the base class.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        quant_op: Optional[torch._ops.OpOverload] = None,
        **kwargs,
    ):
        """Initialize the base helper and capture the tensor-parallel group.

        Arguments are forwarded unchanged to ``_RMSNormAndQuantOpHelper``.
        """
        super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` via ``tensor_model_parallel_all_reduce``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 across the tensor-parallel group."""
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 across the tensor-parallel group."""
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )
132
133
134


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> rms_norm`` where the reduced tensor is also used.

    The pattern returns both the normalized output and the all-reduced
    tensor; the replacement returns the all-gathered normalized output and
    the reduce-scattered tensor in the same positions.
    """

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)

        return [input, permute, arg3_1]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            # `permute` is the RMSNorm result buffer, `arg3_1` the weight.
            all_reduce = self._all_reduce(input)
            rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)

            return rmsnorm[1], all_reduce

        def replacement(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)

            # A fresh buffer shaped like the scattered tensor is used instead
            # of `permute` (presumably because `permute` has the un-scattered
            # shape -- TODO confirm).
            rmsnorm_result = torch.empty_like(reduce_scatter)
            rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)

            all_gather = self._all_gather(rmsnorm[1])

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
170
171


172
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> fused_add_rms_norm`` keeping the residual output.

    Both the pattern and the replacement return the normalized output and the
    updated residual; the replacement normalizes the reduce-scattered tensor
    and all-gathers only the normalized output.
    """

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights
            )
            # [1] is the normalized output, [2] the updated residual.
            return rmsnorm[1], rmsnorm[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights
            )
            # Only the normalized output is gathered; the residual is returned
            # as produced by the fused op.
            all_gather = self._all_gather(rmsnorm[1])
            return all_gather, rmsnorm[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
212
213


214
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> fused_add_rms_norm`` when only the output is used.

    Unlike ``MiddleAllReduceRMSNormPattern``, the pattern and the replacement
    return just the normalized output, not the updated residual.
    """

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Both functions return a single tensor, so they are annotated as
        # such rather than as a two-element tuple.
        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights
            )
            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights
            )
            normalized = self._all_gather(rmsnorm[1])
            return normalized

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
254
255
256
257
258
259


# FP8 dtype reported by the current platform, resolved once at import time;
# used as the dtype of the quantized example buffers in this module.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> rms_norm -> quant`` keeping the reduced tensor.

    The pattern returns the quantized output and the all-reduced tensor; the
    replacement returns the all-gathered quantized output and the
    reduce-scattered tensor in the same positions.
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """Initialize the helper with ``op`` as the quantization op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [input, rmsnorm_result, quant_result, weight, scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            static_fp8 = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, all_reduce, weight, scale
            )
            return static_fp8[1], all_reduce

        def replacement(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)

            # Fresh buffers shaped like the scattered tensor; only the dtypes
            # of the matched buffers are reused.
            rmsnorm_buf = torch.empty_like(
                reduce_scatter, dtype=rmsnorm_result.dtype
            )
            quant_buf = torch.empty_like(
                rmsnorm_buf,  # Output of RMSNorm
                dtype=quant_result.dtype,
            )
            static_fp8 = self._functional_rmsnorm_then_quant(
                rmsnorm_buf, quant_buf, reduce_scatter, weight, scale
            )
            all_gather = self._all_gather(static_fp8[1])

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
313
314
315


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> fused_add_rms_norm -> quant`` keeping the residual.

    Both the pattern and the replacement return the quantized output and the
    updated residual; the replacement works on the reduce-scattered tensor
    and all-gathers only the quantized output.
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """Initialize the helper with ``op`` as the quantization op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, rmsnorm_residual_out = (
                self._functional_fused_add_rmsnorm_then_quant(  # noqa: E501
                    result, all_reduce, residual, rms_norm_weights, scale
                )
            )
            return static_fp8[1], rmsnorm_residual_out

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(mm_1)
            # Fresh quantization buffer shaped like the scattered tensor; only
            # the dtype of the matched `result` buffer is reused.
            quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
            static_fp8, rmsnorm_residual_out = (
                self._functional_fused_add_rmsnorm_then_quant(  # noqa: E501
                    quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
                )
            )
            all_gather = self._all_gather(static_fp8[1])
            return all_gather, rmsnorm_residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
373
374
375


class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> fused_add_rms_norm -> quant``, output only.

    Unlike ``MiddleAllReduceRMSNormStaticFP8Pattern``, the pattern and the
    replacement return just the quantized output and discard the residual.
    """

    def __init__(
        self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
    ):
        """Initialize the helper with ``op`` as the quantization op."""
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Build the example tensors handed to ``pm.register_replacement``.

        The list order matches the parameter order of ``pattern`` and
        ``replacement`` in ``register``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        # Both functions return a single tensor, so they are annotated as
        # such rather than as a two-element tuple.
        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                result, all_reduce, residual, rms_norm_weights, scale
            )
            return static_fp8[1]

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            # Fresh quantization buffer shaped like the scattered tensor; only
            # the dtype of the matched `result` buffer is reused.
            quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
            )
            normalized = self._all_gather(static_fp8[1])
            return normalized

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
429
430


431
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        """Register every sequence-parallel pattern for each supported epsilon.

        For each epsilon in ``[1e-5, 1e-6]`` the three static-FP8 patterns are
        registered first, followed by the three plain RMSNorm patterns.
        ``self.model_dtype`` and ``self.device`` are not set here (presumably
        by the base-class initializer -- TODO confirm).
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # The quantization op does not depend on epsilon; look it up once.
        fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default

        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)
            LastAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device, fp8_quant_op
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            LastAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
        self.dump_patterns(config, self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True when ``shape`` is known and divisible by the TP size.

        ``None`` yields ``False``; the check runs before the modulo so it
        never operates on ``None``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply the registered patterns to ``graph`` and record the count.

        The value returned by ``self.patterns.apply`` is stored in
        ``self.matched_count`` and logged at debug level.
        """
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)