"vllm/vscode:/vscode.git/clone" did not exist on "cbc40128eb16ae0045f5aa6d6ee2ff2dda803e23"
sequence_parallelism.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import functools

6
7
8
9
10
11
12
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
13
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
14
from vllm.logger import init_logger
15
16
17
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
18
from vllm.platforms import current_platform
19

20
from .inductor_pass import enable_fake_mode
21
22
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .noop_elimination import NoOpEliminationPass
23
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
24
25
26
27

logger = init_logger(__name__)


28
29
30
31
def get_first_out_wrapper(fn):
    """Return a wrapper around ``fn`` that yields only ``fn(...)[0]``.

    Used to register a single-output variant of a multi-output
    pattern/replacement pair. The wrapper accepts positional arguments only.
    ``functools.wraps`` copies ``fn``'s metadata and sets ``__wrapped__``, so
    ``inspect.signature`` on the wrapper still reports ``fn``'s parameters.
    """

    @functools.wraps(fn)
    def first_output_only(*positional):
        outputs = fn(*positional)
        return outputs[0]

    return first_output_only
34

35

36
class _SequenceParallelPatternHelper:
    """Shared state and collective helpers for sequence-parallel patterns.

    Stores the RMSNorm epsilon, the dtype/device used for example inputs, and
    the tensor-parallel (TP) group. It exposes the three collectives used to
    build pattern and replacement graphs. The scatter and gather collectives
    operate along dim 0.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
    ):
        # Configuration consumed by subclasses (matchers / example inputs).
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device

        # TP group handle and world size, captured once at construction.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _dim0_collective_kwargs(self) -> dict:
        # Keyword arguments shared by the vLLM reduce_scatter / all_gather
        # custom ops: operate on dim 0 across the whole TP group.
        return {
            "dim": 0,
            "world_size": self.tp_size,
            "group_name": self.tp_group.unique_name,
        }

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` via vLLM's ``tensor_model_parallel_all_reduce``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 over the TP group."""
        return torch.ops.vllm.reduce_scatter.default(
            x, **self._dim0_collective_kwargs()
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 over the TP group."""
        return torch.ops.vllm.all_gather.default(
            x, **self._dim0_collective_kwargs()
        )
63
64
65


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of ``AllReduce -> RMSNorm`` (no residual add).

    Pattern:      all_reduce(input) -> rmsnorm; returns (rmsnorm, all_reduce)
    Replacement:  reduce_scatter(input) -> rmsnorm on the local shard
                  -> all_gather; returns (all_gather, reduce_scatter)

    The second output is the reduce-scattered shard after the rewrite
    (NOTE(review): presumably consumed downstream as the residual — confirm).
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self):
        """Example tensors used to trace the pattern: ``[input, weight]``."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        activations = torch.empty([1, 8, 4], **tensor_opts)
        weight = torch.empty([4], **tensor_opts)
        return [activations, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, arg3_1)
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ):
            # Normalize only this rank's shard, then gather the result back.
            shard = self._reduce_scatter(input)
            gathered = self._all_gather(self.rmsnorm_matcher(shard, arg3_1))
            return gathered, shard

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
99
100


101
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Sequence-parallel rewrite of ``AllReduce -> fused-add RMSNorm``.

    Pattern:      all_reduce(mm_1) -> fused_add_rmsnorm(., weights, residual)
    Replacement:  reduce_scatter(mm_1) -> fused_add_rmsnorm on the local shard
                  (with a temporarily sliced residual) -> all_gather of the
                  normalized output

    Two variants are registered: one returning (normalized, residual) and one
    returning only the normalized output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self):
        """Example tensors used to trace the pattern.

        Returns ``[residual, mm_1, rms_norm_weights]``, in the same order as the
        ``pattern``/``replacement`` parameters.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # RMSNorm weights are 1-D over the hidden dim, as in the First*
        # patterns. A 2-D [4, 4] weight would not broadcast correctly against
        # a reduce-scattered [4 // tp_size, 4] shard.
        rms_norm_weights = torch.empty([4], device=self.device, dtype=self.dtype)

        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the two-output and the first-output-only variants."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
            return rmsnorm[0], rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The pattern matcher walks the graph bottom-to-top (starting
            # from the end), so the preceding RMSNorm pattern has not been
            # rewritten yet and `residual` is still full size here.
            # Add a temporary slice that conforms it to the scattered shard;
            # it becomes a no-op once the seqpar pattern with the previous
            # rmsnorm is replaced.
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
            all_gather = self._all_gather(rmsnorm[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, rmsnorm[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Variant whose pattern/replacement return only the normalized output.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
154
155
156
157
158
159


# FP8 dtype reported by the current platform, evaluated once at import time.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite of ``AllReduce -> RMSNorm -> static per-tensor FP8 quant``.

    Pattern:      all_reduce -> rmsnorm -> quant; returns (quant, all_reduce)
    Replacement:  reduce_scatter -> rmsnorm -> quant on the local shard
                  -> all_gather; returns (all_gather, reduce_scatter)
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
    ):
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self):
        """Example tensors used to trace the pattern: ``[input, weight, scale]``."""
        act_opts = {"device": self.device, "dtype": self.dtype}
        activations = torch.zeros([1, 8, 4], **act_opts)
        norm_weight = torch.empty([4], **act_opts)
        # Scalar float32 scale for the static per-tensor quantization.
        quant_scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [activations, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Normalize and quantize only this rank's shard, then gather.
            shard = self._reduce_scatter(input)
            quantized_shard, _ = self.quant_matcher(
                self.rmsnorm_matcher(shard, weight), scale
            )
            return self._all_gather(quantized_shard), shard

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
202
203
204


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite of ``AllReduce -> fused-add RMSNorm -> static FP8 quant``.

    Pattern:      all_reduce(mm_1) -> fused_add_rmsnorm -> quant
    Replacement:  reduce_scatter(mm_1) -> fused_add_rmsnorm (with a temporarily
                  sliced residual) -> quant on the local shard -> all_gather

    Two variants are registered: one returning (quant, residual) and one
    returning only the quantized output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self):
        """Example tensors used to trace the pattern.

        Returns ``[residual, mm_1, rms_norm_weights, scale]``, in the same
        order as the ``pattern``/``replacement`` parameters.
        """
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        # RMSNorm weights are 1-D over the hidden dim, as in the First*
        # patterns. A 2-D [4, 4] weight would not broadcast correctly against
        # a reduce-scattered [4 // tp_size, 4] shard.
        rms_norm_weights = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the two-output and the first-output-only variants."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            return quant, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The pattern matcher walks the graph bottom-to-top (starting
            # from the end), so the preceding RMSNorm pattern has not been
            # rewritten yet and `residual` is still full size here.
            # Add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced.
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Variant whose pattern/replacement return only the quantized output.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
264
265


266
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.

    This pass splits up the residual tensor across TP ranks and hence divides
    its size by the TP world size. Because the pattern matcher starts at the
    end of the graph, the replacement contains a slice that temporarily
    conforms the input residual to the correct size. After all patterns have
    been matched, a NoOpEliminationPass cleans up what have now become no-op
    slices.

    Note that an older version of the pass did not need this as it operated
    only on the custom rms_norm and fused_rms_norm_add ops, which did not
    complain about mismatched shapes during replacement. This approach keeps
    the same assumption: correctness is only maintained if all rms_norm
    operations are split across ranks.

    Correctness-wise, this approach is strictly better than before: previously
    the graph was incorrect both semantically and shape-wise during the pass,
    whereas now it is only semantically incorrect during the pass.
    Both approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # For each epsilon, patterns are registered in this exact order.
        # NOTE(review): FP8 patterns come first, presumably so the longer
        # RMSNorm + quant patterns are preferred over plain RMSNorm — confirm
        # before reordering.
        pattern_classes = (
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern,
            MiddleAllReduceRMSNormStaticFP8Pattern,
            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(epsilon, self.model_dtype, self.device).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    def is_applicable(self, shape: int | None) -> bool:
        """Return whether the pass may run for the given compile ``shape``."""
        # When sequence parallelism is enabled, the residual tensor from RMSNorm
        # needs to be split along the sequence dimension. However, this dimension
        # is symbolic during piecewise compilation, and splitting symbolic shapes
        # is not supported.
        #
        # This pass is therefore only applied when the sequence dimension is
        # concrete:
        # 1. In full-graph compilation mode (no Dynamo splitting ops are used).
        #   For this case we always pad num_tokens to be a multiple of
        #   tensor_parallel_size, so there's no need to check shape % tp_size == 0.
        # 2. For specific shape provided during compilation (e.g., from
        #    `compile_sizes`), which must be divisible by the tensor-parallel
        #    size.
        if (
            not self.compilation_config.splitting_ops
            or self.compilation_config.use_inductor_graph_partition
        ):
            return True
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph):
        """Apply all registered patterns to ``graph``, then drop no-op slices."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)