# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
from collections.abc import Callable
from typing import Any

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform

from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

# Module-level logger; SequenceParallelismPass.__call__ uses it to report
# how many patterns were replaced.
logger = init_logger(__name__)


def get_first_out_wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``fn`` so that only the first element of its result is returned.

    The wrapper forwards its positional arguments to ``fn`` unchanged and
    indexes the result with ``[0]``. ``functools.wraps`` copies the metadata
    of ``fn`` (name, docstring, ``__wrapped__``) onto the wrapper.

    Args:
        fn: A callable whose return value supports ``[0]`` indexing.

    Returns:
        A callable accepting the same positional arguments as ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args):
        return fn(*args)[0]

    return wrapper


class _SequenceParallelPatternHelper:
    """Helper for sequence parallelism patterns.

    Holds the parameters shared by the pattern classes (``epsilon``, ``dtype``
    and ``device``) together with the tensor-parallel group and its world
    size, and wraps the collective ops that the pattern and replacement
    functions are built from.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        """Store the pattern parameters and look up the TP group and size.

        Args:
            epsilon: Epsilon handed to the RMSNorm matchers by subclasses.
            dtype: dtype of the example tensors built in ``get_inputs``.
            device: Device of the example tensors built in ``get_inputs``.
        """
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tensor_model_parallel_all_reduce(x)``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the vLLM ``reduce_scatter`` op to ``x`` along dim 0."""
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the vLLM ``all_gather`` op to ``x`` along dim 0."""
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of an all-reduce followed by RMSNorm.

    Matches ``all_reduce -> rmsnorm`` (no residual input) and replaces it with
    ``reduce_scatter -> rmsnorm -> all_gather``. The pattern also exposes the
    all-reduce output as its second result; in the replacement that slot is
    filled by the reduce-scatter output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors used to trace the pattern."""
        example_input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)

        return [example_input, arg3_1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)

            return rmsnorm, all_reduce

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)

            # The norm runs on the scattered shard; only its output is
            # gathered back.
            rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
            all_gather = self._all_gather(rmsnorm)

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of an all-reduce followed by fused-add RMSNorm.

    Matches ``all_reduce -> fused_add_rmsnorm(residual)`` and replaces it with
    ``reduce_scatter -> fused_add_rmsnorm(sliced residual) -> all_gather``.
    The pair is registered twice: once returning both matcher outputs and once
    returning only the first one.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors used to trace the pattern.

        The returned order (residual, mm_1, rms_norm_weights) matches the
        parameter order of ``pattern`` and ``replacement``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both variants of the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
            return rmsnorm[0], rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
            all_gather = self._all_gather(rmsnorm[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, rmsnorm[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Second variant: same pair, but only the first output is returned.
        # NOTE(review): presumably covers graphs where the residual output is
        # unused - confirm.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )


# FP8 dtype reported by the current platform.
# NOTE(review): not referenced anywhere else in the visible part of this
# module - confirm whether other modules import it.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of all-reduce + RMSNorm + static FP8 quant.

    Matches ``all_reduce -> rmsnorm -> quant_fp8`` (no residual input) and
    replaces it with ``reduce_scatter -> rmsnorm -> quant_fp8 -> all_gather``,
    so the all-gather operates on the quantized tensor. The pattern also
    exposes the all-reduce output as its second result; in the replacement
    that slot is filled by the reduce-scatter output.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors used to trace the pattern."""
        example_input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [example_input, weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            all_reduce = self._all_reduce(input)
            rms = self.rmsnorm_matcher(all_reduce, weight)
            # Only the first output of the quant matcher is used.
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduce_scatter = self._reduce_scatter(input)
            rms = self.rmsnorm_matcher(reduce_scatter, weight)
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of all-reduce + fused-add RMSNorm + FP8 quant.

    Matches ``all_reduce -> fused_add_rmsnorm(residual) -> quant_fp8`` and
    replaces it with ``reduce_scatter -> fused_add_rmsnorm(sliced residual)
    -> quant_fp8 -> all_gather``. The pair is registered twice: once returning
    (quant, residual_out) and once returning only the first output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the example tensors used to trace the pattern.

        The returned order (residual, mm_1, rms_norm_weights, scale) matches
        the parameter order of ``pattern`` and ``replacement``.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both variants of the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            # Only the first output of the quant matcher is used.
            quant, _ = self.quant_matcher(rms, scale)
            return quant, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Second variant: same pair, but only the first output is returned.
        # NOTE(review): presumably covers graphs where the residual output is
        # unused - confirm.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )


class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.

    This pass splits up the residual tensor across TP ranks and hence divides
    its size. Because the pattern matcher starts at the end of the graph, the
    replacement contains a slice that temporarily conforms the input residual
    to the correct size. After all patterns have been matched, we use a
    NoOpEliminationPass to clean up what have now become no-op slices.

    Note that an older version of the pass did not need this as it operated
    only on custom rms_norm and fused_rms_norm_add custom ops which did not
    complain about mismatched shapes during replacement. So this approach has
    the same assumption that correctness is only maintained if all rms_norm
    operations are split across ranks.

    Correctness-wise, this approach is strictly better than before - before,
    the graph was incorrect semantically and shape-wise during the pass.
    With this approach there's only semantic incorrectness during the pass.
    Both approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Create the cleanup pass and register all patterns.

        Every pattern class is registered once per epsilon in
        ``[1e-5, 1e-6]``; the FP8 patterns are registered before the plain
        RMSNorm patterns.
        """
        super().__init__(config)

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return whether the pass may run for ``compile_range``."""
        # When sequence parallelism is enabled, the residual tensor from RMSNorm
        # needs to be split along the sequence dimension. However, this dimension
        # is symbolic during piecewise compilation, and splitting symbolic shapes
        # is not supported.
        #
        # This pass is therefore only applied when the sequence dimension is
        # concrete:
        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
        #   For this case we always pad num_tokens to be a multiple of
        #   tensor_parallel_size, so there's no need to check shape % tp_size == 0.
        # 2. For specific shape provided during compilation (e.g., from
        #    `compile_sizes`), which must be divisible by the tensor-parallel
        #    size.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        tp_size = get_tensor_model_parallel_world_size()
        return compile_range.is_single_size() and compile_range.end % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph``, then run the cleanup pass."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)