# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inductor pass that rewrites AllReduce + RMSNorm into sequence-parallel form.

Matched ``all_reduce -> rms_norm`` (optionally followed by static FP8
quantization) subgraphs are replaced by
``reduce_scatter -> rms_norm -> all_gather``.
"""
from typing import Optional

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class _RMSNormAndQuantOpHelper:
    """Base helper for RMSNorm and RMSNorm + Quantization functionalization.

    The ``_functional_*`` methods wrap vLLM's in-place custom ops
    (``torch.ops._C.rms_norm``, ``torch.ops._C.fused_add_rms_norm`` and an
    optional quantization op) in ``auto_functionalized`` so they can be used
    inside pattern-matcher search and replacement functions.
    """

    def __init__(self,
                 epsilon: float,
                 dtype: torch.dtype,
                 device: str,
                 quant_op: Optional[torch._ops.OpOverload] = None,
                 **kwargs):
        """
        Args:
            epsilon: RMSNorm epsilon passed to every emitted norm call.
            dtype: dtype stored for subclasses (used for example inputs).
            device: device stored for subclasses (used for example inputs).
            quant_op: op applied after RMSNorm by the ``*_then_quant``
                helpers; when ``None`` those helpers raise RuntimeError.
            **kwargs: accepted for subclass compatibility and ignored.
        """
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.quant_op = quant_op

    def _require_quant_op(self) -> torch._ops.OpOverload:
        """Return the configured quant op, raising if none was given."""
        if self.quant_op is None:
            raise RuntimeError(
                "_RMSNormAndQuantOpHelper was not initialized with a quant_op."
            )
        return self.quant_op

    def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
        """Emit a functionalized ``rms_norm`` writing into ``result_buffer``.

        Callers read index 1 of the returned tuple as the normalized output.
        """
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.rms_norm.default,
            result=result_buffer,
            input=input_tensor,
            weight=weight_tensor,
            epsilon=self.epsilon)

    def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor,
                                      weight_tensor):
        """Emit a functionalized ``fused_add_rms_norm``.

        Callers read index 1 of the returned tuple as the normalized output
        and index 2 as the updated residual.
        """
        return torch.ops.higher_order.auto_functionalized(
            torch.ops._C.fused_add_rms_norm.default,
            input=input_tensor,
            residual=residual_tensor,
            weight=weight_tensor,
            epsilon=self.epsilon)

    def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer,
                                       quant_result_buffer, input_tensor,
                                       weight_tensor, scale_tensor):
        """RMSNorm into ``rmsnorm_result_buffer``, then quantize the result.

        Returns the ``auto_functionalized`` tuple of the quant op; callers
        read index 1 as the quantized output.

        Raises:
            RuntimeError: if the helper was built without a ``quant_op``.
        """
        quant_op = self._require_quant_op()
        rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer,
                                                     input_tensor,
                                                     weight_tensor)
        return torch.ops.higher_order.auto_functionalized(
            quant_op,
            result=quant_result_buffer,
            input=rmsnorm_out_tuple[1],
            scale=scale_tensor)

    def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer,
                                                 input_tensor, residual_tensor,
                                                 weight_tensor, scale_tensor):
        """Fused residual-add + RMSNorm, then quantize the normalized output.

        Returns:
            ``(quant_tuple, residual_out)``: the quant op's
            ``auto_functionalized`` tuple (index 1 is the quantized output)
            and the updated residual from the fused add/norm.

        Raises:
            RuntimeError: if the helper was built without a ``quant_op``.
        """
        quant_op = self._require_quant_op()
        fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
            input_tensor, residual_tensor, weight_tensor)
        quant_out_tuple = torch.ops.higher_order.auto_functionalized(
            quant_op,
            result=quant_result_buffer,
            input=fused_add_rmsnorm_out_tuple[1],
            scale=scale_tensor)
        return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]


class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
    """Adds tensor-parallel collectives on top of the RMSNorm helpers.

    Reduce-scatter and all-gather both operate along dim 0 of the tensor
    model-parallel group captured when the helper is constructed.
    """

    def __init__(self,
                 epsilon: float,
                 dtype: torch.dtype,
                 device: str,
                 quant_op: Optional[torch._ops.OpOverload] = None,
                 **kwargs):
        super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
        # Snapshot the TP group handle and its size once; every collective
        # emitted by this helper targets that group.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` across the tensor-parallel group."""
        reduced = tensor_model_parallel_all_reduce(x)
        return reduced

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 across the tensor-parallel group."""
        scatter_op = torch.ops.vllm.reduce_scatter.default
        return scatter_op(x,
                          dim=0,
                          world_size=self.tp_size,
                          group_name=self.tp_group.unique_name)

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 across the tensor-parallel group."""
        gather_op = torch.ops.vllm.all_gather.default
        return gather_op(x,
                         dim=0,
                         world_size=self.tp_size,
                         group_name=self.tp_group.unique_name)


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite an AllReduce feeding a plain RMSNorm (no residual add).

    Pattern:     all_reduce(x) -> rms_norm
    Replacement: reduce_scatter(x) -> rms_norm -> all_gather

    The pattern's second output (the all-reduced tensor) is replaced by the
    reduce-scattered shard. Downstream consumers of it therefore keep
    working on the sharded tensor; presumably this is the residual stream
    that the middle/last patterns consume.
    """

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""
        x = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        rmsnorm_out = torch.empty([1, 8, 4],
                                  device=self.device,
                                  dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        return [x, rmsnorm_out, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
            return rmsnorm[1], all_reduce

        def replacement(
            input: torch.Tensor,
            permute: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)
            # ``permute`` was sized for the un-scattered tensor, so allocate
            # an output buffer shaped like the local shard instead.
            rmsnorm_result = torch.empty_like(reduce_scatter)
            rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter,
                                               arg3_1)
            all_gather = self._all_gather(rmsnorm[1])
            return all_gather, reduce_scatter

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite AllReduce -> fused residual-add + RMSNorm, keeping both outputs.

    Pattern:     all_reduce(mm_1) -> fused_add_rms_norm(., residual)
    Replacement: reduce_scatter(mm_1) -> fused_add_rms_norm(., residual)
                 -> all_gather (normalized output only)

    In the replacement, ``residual`` is combined directly with the local
    shard. It is therefore expected to already be sharded, presumably
    produced by an earlier rewritten norm. The updated residual is returned
    still sharded.
    """

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # NOTE(review): the weight is 2-D here, while
        # FirstAllReduceRMSNormPattern traces with a 1-D [4] weight;
        # confirm this is intended.
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights)
            # (normalized output, updated residual)
            return rmsnorm[1], rmsnorm[2]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights)
            # Only the normalized output is gathered; the residual stays
            # sharded for the next matched norm.
            all_gather = self._all_gather(rmsnorm[1])
            return all_gather, rmsnorm[2]

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite AllReduce -> fused add + RMSNorm whose residual is unused.

    Same shape as ``MiddleAllReduceRMSNormPattern``, except that the
    updated residual is not an output of the matched subgraph. The
    replacement therefore only returns the all-gathered normalized
    output.
    """

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                all_reduce, residual, rms_norm_weights)
            return rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            rmsnorm = self._functional_fused_add_rmsnorm(
                reduce_scatter, residual, rms_norm_weights)
            normalized = self._all_gather(rmsnorm[1])
            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


# FP8 dtype reported by the current platform; used as the dtype of the
# quantized-output example buffers in the static-FP8 patterns.
FP8_DTYPE: torch.dtype = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of AllReduce -> RMSNorm -> static FP8 quant.

    Pattern:     all_reduce(x) -> rms_norm -> quant
    Replacement: reduce_scatter(x) -> rms_norm -> quant -> all_gather

    As in the non-quantized first pattern, the pattern's second output
    (the all-reduced tensor) is replaced by the reduce-scattered shard.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""
        dev = self.device
        hidden = torch.zeros([1, 8, 4], device=dev, dtype=self.dtype)
        norm_buf = torch.empty([1, 8, 4], device=dev, dtype=self.dtype)
        quant_buf = torch.empty([1, 8, 4], device=dev, dtype=FP8_DTYPE)
        norm_weight = torch.empty([4], device=dev, dtype=self.dtype)
        quant_scale = torch.tensor(1.0, device=dev, dtype=torch.float32)
        return [hidden, norm_buf, quant_buf, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            quantized = self._functional_rmsnorm_then_quant(
                rmsnorm_result, quant_result, reduced, weight, scale)
            return quantized[1], reduced

        def replacement(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            shard = self._reduce_scatter(input)
            # Fresh buffers shaped like the local shard, keeping the dtypes
            # of the buffers the pattern was traced with.
            shard_norm_buf = torch.empty_like(shard,
                                              dtype=rmsnorm_result.dtype)
            shard_quant_buf = torch.empty_like(shard_norm_buf,
                                               dtype=quant_result.dtype)
            quantized = self._functional_rmsnorm_then_quant(
                shard_norm_buf, shard_quant_buf, shard, weight, scale)
            return self._all_gather(quantized[1]), shard

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of AllReduce -> fused add+RMSNorm -> quant.

    Both the quantized output and the updated residual are outputs. Only
    the quantized output is all-gathered; the residual stays in its
    reduce-scattered form.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""

        def _empty(shape, dtype):
            return torch.empty(shape, device=self.device, dtype=dtype)

        matmul_out = _empty([4, 4], self.dtype)
        residual_in = _empty([4, 4], self.dtype)
        norm_weight = _empty([4, 4], self.dtype)
        quant_out = _empty([4, 4], FP8_DTYPE)
        quant_scale = _empty([1, 1], torch.float32)

        return [quant_out, residual_in, matmul_out, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""
        fused_then_quant = self._functional_fused_add_rmsnorm_then_quant

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            quant_tuple, residual_out = fused_then_quant(
                result, reduced, residual, rms_norm_weights, scale)
            return quant_tuple[1], residual_out

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            shard = self._reduce_scatter(mm_1)
            shard_quant_buf = torch.empty_like(shard, dtype=result.dtype)
            quant_tuple, residual_out = fused_then_quant(
                shard_quant_buf, shard, residual, rms_norm_weights, scale)
            return self._all_gather(quant_tuple[1]), residual_out

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of AllReduce -> fused add+RMSNorm -> quant
    where the updated residual is not used.

    The replacement returns only the all-gathered quantized output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 op: torch._ops.OpOverload):
        super().__init__(epsilon, dtype, device, quant_op=op)

    def get_inputs(self):
        """Example tensors used to trace ``pattern`` and ``replacement``."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.dtype)
        result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [
            result,
            residual,
            mm_1,
            rms_norm_weights,
            scale,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            all_reduce = self._all_reduce(mm_1)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                result, all_reduce, residual, rms_norm_weights, scale)
            return static_fp8[1]

        def replacement(
            result: torch.Tensor,
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            reduce_scatter = self._reduce_scatter(mm_1)
            # Quantized output buffer sized to the local shard, with the
            # dtype of the buffer the pattern was traced with.
            quant_result_buf = torch.empty_like(reduce_scatter,
                                                dtype=result.dtype)
            static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
                quant_result_buf, reduce_scatter, residual, rms_norm_weights,
                scale)
            normalized = self._all_gather(static_fp8[1])
            return normalized

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class SequenceParallelismPass(VllmInductorPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    The replacements keep the residual tensor in its reduce-scattered
    (sharded) form between matched norms; only the normalized outputs are
    all-gathered.

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass")

        # The quant op does not depend on epsilon, so resolve it once.
        fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default

        # Each pattern bakes its epsilon into the traced norm call, so one
        # copy of every pattern is registered per supported epsilon.
        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)
            LastAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device,
                fp8_quant_op).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)
            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                          self.device).register(self.patterns)
            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                        self.device).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        """Return True if ``shape`` is set and divisible by the TP size.

        The rewrite reduce-scatters along dim 0, so only sizes that split
        evenly across tensor-parallel ranks qualify. ``None`` (presumably
        an unspecified/dynamic size) is rejected.
        """
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph`` in place."""
        self.begin()
        self.dump_graph(graph, "before_sequence_parallelism_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns with sequence parallelism", count)
        self.dump_graph(graph, "after_sequence_parallelism_pass")
        self.end_and_log()