sequence_parallelism.py 14.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
6
from collections.abc import Callable, Sequence
from typing import Any
7

8
9
10
11
12
13
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
14
from vllm.config.compilation import Range
15
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
16
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
19
20
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
21
from vllm.platforms import current_platform
22

23
from .inductor_pass import enable_fake_mode
24
25
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
26
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
27
28
29
30

# Module-level logger; SequenceParallelismPass.__call__ uses it to report how
# many patterns were replaced.
logger = init_logger(__name__)


31
32
33
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """Adapt a multi-output callable so it returns only its first output.

    ``functools.wraps`` copies ``fn``'s metadata (name, docstring, and a
    ``__wrapped__`` back-reference) onto the adapter. Signature introspection
    therefore still reports ``fn``'s parameters. This file uses it to register
    single-output variants of multi-output pattern/replacement pairs.

    Args:
        fn: Callable returning an indexable sequence of tensors.

    Returns:
        A callable taking the same positional arguments as ``fn`` and
        returning element ``0`` of ``fn``'s result.
    """

    @functools.wraps(fn)
    def first_only(*args: Any) -> torch.Tensor:
        outputs = fn(*args)
        return outputs[0]

    return first_only
39

40

41
class _SequenceParallelPatternHelper:
    """Common base for the sequence-parallelism pattern classes.

    Stores the RMSNorm epsilon, model dtype and tracing device for subclasses.
    It also looks up the tensor-parallel group and world size once, at
    construction time, so the collective helpers below can reuse them.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon, self.dtype, self.device = epsilon, dtype, device
        # Captured once so every traced collective targets the same TP group.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` over the tensor-parallel group."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 over the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 over the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )
68
69
70


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for an AllReduce feeding a plain RMSNorm.

    Pattern: ``all_reduce(input)`` followed by RMSNorm. It outputs the
    normalized tensor and the all-reduced tensor.

    Replacement: ``reduce_scatter(input)``, RMSNorm on the local shard, then
    ``all_gather`` of the normalized shard. The reduce-scattered shard takes
    the place of the all-reduced tensor as the second output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing: the activation and the RMSNorm weight.

        They are created with ``torch.empty``, so their contents are
        uninitialized; only their shapes, dtype and device are set here.
        """
        return [
            torch.empty([1, 8, 4], device=self.device, dtype=self.dtype),
            torch.empty([4], device=self.device, dtype=self.dtype),
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Normalize only this rank's shard, then gather the normalized
            # result back so downstream consumers still see the full tensor.
            shard = self._reduce_scatter(input)
            normed_shard = self.rmsnorm_matcher(shard, weight)
            gathered = self._all_gather(normed_shard)
            return gathered, shard

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
104
105


106
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for an AllReduce feeding a fused add + RMSNorm.

    Pattern: ``all_reduce(mm_1)`` is combined with ``residual`` by the fused
    add+RMSNorm matcher. It returns the normalized tensor and the updated
    residual.

    Replacement: ``reduce_scatter(mm_1)``, fused add+RMSNorm on the local
    shard, and ``all_gather`` of the normalized tensor. The updated residual
    is returned shard-sized.

    A second, first-output-only variant is also registered (presumably for
    graphs where the updated residual has no users).
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing: ``[residual, mm_1, rms_norm_weights]``.

        The order must match the parameter order of ``pattern``/``replacement``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # NOTE(review): the weight is traced as 2-D here, while the non-fused
        # patterns trace a 1-D weight. Presumably harmless for matching; confirm
        # before changing the shape.
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output pattern and its first-output-only variant."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            return rms, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The pattern matcher visits nodes starting from the end of the
            # graph. When this replacement is inserted, the earlier seqpar
            # pattern that produces ``residual`` has not been rewritten yet,
            # so ``residual`` is still full size.
            # Add a temporary slice that conforms it to the reduce-scattered
            # shard. It becomes a no-op once that earlier pattern is replaced.
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            all_gather = self._all_gather(rms)
            # The returned residual is now shard-sized. That is fine: the next
            # seqpar pattern downstream (already rewritten, because matching
            # runs from the end of the graph) slices it, and that slice now
            # becomes a no-op.
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
159
160
161
162
163
164


# FP8 dtype reported by the current platform, resolved once at import time.
# NOTE(review): not referenced by any pattern visible in this module —
# presumably kept for external importers; verify before removing.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for AllReduce -> RMSNorm -> static FP8 quant.

    Pattern: outputs the quantized tensor and the all-reduced tensor.

    Replacement: RMSNorm and quantization run on each rank's reduce-scattered
    shard, and the quantized shard is all-gathered. The shard itself is
    returned in place of the all-reduced tensor.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        # Static, per-tensor symmetric FP8 quantization.
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing: activation, RMSNorm weight, FP8 scale."""
        activation = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        norm_weight = torch.empty([4], device=self.device, dtype=self.dtype)
        quant_scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [activation, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _unused = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Normalize and quantize only the local shard; gather afterwards.
            shard = self._reduce_scatter(input)
            normed_shard = self.rmsnorm_matcher(shard, weight)
            quantized_shard, _unused = self.quant_matcher(normed_shard, scale)
            gathered = self._all_gather(quantized_shard)
            return gathered, shard

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
207
208
209


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite for AllReduce -> fused add+RMSNorm -> static FP8 quant.

    Pattern: ``all_reduce(mm_1)`` is combined with ``residual`` by the fused
    add+RMSNorm matcher and the result is quantized. It returns the quantized
    tensor and the updated residual.

    Replacement: the same work runs on the reduce-scattered shard, and the
    quantized shard is all-gathered. The updated residual is returned
    shard-sized.

    A first-output-only variant is also registered (presumably for graphs
    where the updated residual has no users).
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        # Static, per-tensor symmetric FP8 quantization.
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors: ``[residual, mm_1, rms_norm_weights, scale]``.

        The order must match the parameter order of ``pattern``/``replacement``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # NOTE(review): the scale is traced as [1, 1] here but as a 0-d tensor
        # in the first-layer FP8 pattern. Presumably both match; confirm
        # before unifying.
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output pattern and its first-output-only variant."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            return quant, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The pattern matcher visits nodes starting from the end of the
            # graph. When this replacement is inserted, the earlier seqpar
            # pattern that produces ``residual`` has not been rewritten yet,
            # so ``residual`` is still full size.
            # Add a temporary slice that conforms it to the reduce-scattered
            # shard. It becomes a no-op once that earlier pattern is replaced.
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)
            # The returned residual is now shard-sized. That is fine: the next
            # seqpar pattern downstream (already rewritten, because matching
            # runs from the end of the graph) slices it, and that slice now
            # becomes a no-op.
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
269
270


271
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    Enable sequence parallelism by rewriting AllReduce + RMSNorm patterns.

    The pass looks for an AllReduce followed by an RMSNorm, optionally followed
    by static FP8 quantization. It replaces each match with a ReduceScatter, a
    local RMSNorm/quantization on each rank's shard, and an AllGather:

    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    The rewrite alone does not improve performance. It sets up later fusion
    passes, such as GEMM + ReduceScatter and AllGather + GEMM, which can
    significantly reduce communication overhead and improve overall model
    performance.

    The residual tensor is split across TP ranks, so its size is divided.
    Because the pattern matcher starts at the end of the graph, each replacement
    contains a slice that temporarily conforms the incoming residual to the
    sharded size. After all patterns have been matched, a NoOpEliminationPass
    removes the slices that have become no-ops.

    An older version of the pass did not need these slices. It operated only on
    the custom rms_norm and fused_rms_norm_add ops, which did not complain about
    mismatched shapes during replacement. This approach keeps the same
    assumption: correctness is only maintained if every rms_norm operation is
    split across ranks.

    Correctness-wise, this approach is strictly better than before. Previously
    the graph was incorrect both semantically and shape-wise during the pass;
    now it is only semantically incorrect during the pass. Both approaches
    restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Removes the temporary residual slices once they have become no-ops.
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # Registration order per epsilon: FP8 variants first, then the plain
        # RMSNorm variants (first-layer before middle-layer in each group).
        pattern_types = (
            FirstAllReduceRMSNormStaticFP8Pattern,
            MiddleAllReduceRMSNormStaticFP8Pattern,
            FirstAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
        )
        for epsilon in (1e-5, 1e-6):
            for pattern_type in pattern_types:
                pattern_type(epsilon, self.model_dtype, self.device).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether the pass may run for ``compile_range``.

        Splitting the residual needs a concrete sequence dimension, because
        splitting a symbolic size (as seen in piecewise compilation) is not
        supported. The pass therefore runs only when:

        1. Dynamo does not split the graph (no splitting ops), or graph
           partitioning is done by Inductor. num_tokens is then always padded
           to a multiple of the TP size, so divisibility is not re-checked.
        2. The range is a single concrete size (e.g. from ``compile_sizes``)
           that is divisible by the TP size.
        """
        cfg = self.compilation_config
        whole_graph = not cfg.splitting_ops or cfg.use_inductor_graph_partition
        if whole_graph:
            return True

        tp_size = get_tensor_model_parallel_world_size()
        return compile_range.is_single_size() and compile_range.end % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph``, then drop no-op slices."""
        replaced = self.patterns.apply(graph)
        self.matched_count = replaced
        logger.debug("Replaced %s patterns", replaced)
        # Clean up the temporary reshape/slice nodes left by the replacements.
        self.noop_cleanup(graph)