sequence_parallelism.py 17.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
6
from collections.abc import Callable, Sequence
from typing import Any
7

8
9
10
11
12
13
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
14
from vllm.config.utils import Range
15
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
16
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
19
20
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
21

22
23
24
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
25
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
26
27
28

logger = init_logger(__name__)

# Per-device gating tables for sequence parallelism (SP).
# Keys are device capabilities in integer form, i.e. the value of
# `current_platform.get_device_capability().to_int()` (90 is the H100 entry
# below). A capability missing from either table makes
# get_sequence_parallelism_threshold() return None, which disables SP.

# Smallest model hidden_size for which SP is considered on each device.
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    90: 8192,  # H100: models with hidden_size >= 8192 only
}

# Per-GPU size budget (MiB) used to derive the SP token threshold.
# The overall budget is per_gpu_mb * tp_size, so the threshold grows with
# the tensor-parallel degree instead of staying fixed.
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    90: 8,  # H100: 8 MiB per GPU
}


def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Return the minimum number of tokens for which sequence parallelism applies.

    ``None`` means SP should not be used at all for this model/device. That
    happens when:
    - the current platform is not CUDA, or its device capability is unknown;
    - the capability has no entry in SP_MIN_HIDDEN_SIZE or
      SP_MIN_PER_GPU_SIZE_MB;
    - ``hidden_size`` is below the device's SP_MIN_HIDDEN_SIZE entry.

    Otherwise, the threshold is the number of tokens whose activations
    (``hidden_size * element_size`` bytes per token) fill the per-GPU budget
    scaled by the TP degree:

        min_token_num = (min_per_gpu_size_mb * MiB * tp_size)
                        // (hidden_size * element_size)

    Args:
        hidden_size: Model hidden dimension.
        tp_size: Tensor-parallel world size.
        element_size: Bytes per activation element.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    capability = current_platform.get_device_capability()
    if capability is None:
        return None
    capability_key = capability.to_int()

    hidden_floor = SP_MIN_HIDDEN_SIZE.get(capability_key)
    per_gpu_mb = SP_MIN_PER_GPU_SIZE_MB.get(capability_key)
    # Unconfigured device, or a model too small to benefit: no SP.
    if hidden_floor is None or per_gpu_mb is None or hidden_size < hidden_floor:
        return None

    bytes_per_token = hidden_size * element_size
    budget_bytes = per_gpu_mb * (1 << 20) * tp_size
    return int(budget_bytes // bytes_per_token)

86

87
88
89
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """
    Adapt a multi-output function so that it returns only its first output.

    The adapter is decorated with ``functools.wraps(fn)``. It therefore carries
    ``fn``'s name and a ``__wrapped__`` reference, so ``inspect.signature``
    reports ``fn``'s original parameters. Presumably this lets the pattern
    matcher see the same argument names as the unwrapped pattern.
    """

    def first_output(*args: Any) -> torch.Tensor:
        outputs = fn(*args)
        return outputs[0]

    return functools.wraps(fn)(first_output)
95

96

97
class _SequenceParallelPatternHelper:
    """
    Common base for the sequence-parallelism patterns in this module.

    Stores the RMSNorm epsilon, dtype and device that subclasses use to build
    their matchers and example inputs. Also wraps the three tensor-parallel
    collectives the patterns are written in terms of. ReduceScatter and
    AllGather are both issued with ``dim=0``.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon, self.dtype, self.device = epsilon, dtype, device
        # Captured once so every collective traced by a pattern uses the same
        # TP group name and world size.
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` over the TP group."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` over the TP group along dim 0."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` over the TP group along dim 0."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )
124
125
126


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    SP rewrite for AllReduce followed by a plain (residual-free) RMSNorm.

    Matched:      y = rmsnorm(all_reduce(x)); outputs (y, all_reduce(x))
    Replacement:  s = reduce_scatter(x); outputs (all_gather(rmsnorm(s)), s)

    The second output becomes the rank-local shard instead of the full
    all-reduced tensor.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors used to trace the pattern: (input, norm weight)."""
        hidden = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        return [hidden, weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, arg3_1)
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Each rank normalizes only its shard, then shards are gathered.
            scattered = self._reduce_scatter(input)
            gathered = self._all_gather(self.rmsnorm_matcher(scattered, arg3_1))
            return gathered, scattered

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
160
161


162
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    SP rewrite for AllReduce followed by a fused residual-add RMSNorm.

    Matched:      (y, r') = fused_add_rmsnorm(all_reduce(mm), w, residual)
    Replacement:  s = reduce_scatter(mm)
                  (y_local, r') = fused_add_rmsnorm(s, w, residual[:len(s)])
                  outputs (all_gather(y_local), r')

    The residual output stays sharded (rank-local) after the rewrite.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors used to trace the pattern: (residual, mm_1, weights)."""

        def make() -> torch.Tensor:
            return torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual, mm_1, rms_norm_weights = make(), make(), make()
        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the full and first-output-only variants with ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, residual_out = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            return normed, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            # The matcher works from the end of the graph, so the preceding
            # SP pattern has not been rewritten yet and `residual` is still
            # full size here. Add a temporary slice down to this rank's
            # shard length. It becomes a no-op once the seqpar pattern with
            # the previous rmsnorm is replaced.
            local_residual = residual[0 : scattered.size(0), ...]
            normed, residual_out = self.rmsnorm_matcher(
                scattered, rms_norm_weights, local_residual
            )
            gathered = self._all_gather(normed)
            # residual_out now has the sharded shape; that's fine because the
            # next node already slices it, which now becomes a no-op.
            return gathered, residual_out

        # The second pair matches uses where only the normalized output is
        # consumed (presumably where the updated residual is unused).
        variants = (
            (pattern, replacement),
            (get_first_out_wrapper(pattern), get_first_out_wrapper(replacement)),
        )
        for search_fn, replace_fn in variants:
            pm.register_replacement(
                search_fn, replace_fn, self.get_inputs(), pm.fwd_only, pm_pass
            )
215
216
217


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    SP rewrite for AllReduce -> plain RMSNorm -> static per-tensor FP8 quant.

    Matched:      q = quant(rmsnorm(all_reduce(x)), scale); outputs
                  (q, all_reduce(x))
    Replacement:  s = reduce_scatter(x); outputs
                  (all_gather(quant(rmsnorm(s), scale)), s)

    Quantization happens before the AllGather, so the gathered tensor is the
    quantized one.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors used to trace the pattern: (input, weight, scale)."""
        hidden = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        norm_weight = torch.empty([4], device=self.device, dtype=self.dtype)
        quant_scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [hidden, norm_weight, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern/replacement pair with ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return self._all_gather(quantized), scattered

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
260
261
262


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    SP rewrite for AllReduce -> fused residual-add RMSNorm -> static FP8 quant.

    Matched:      (y, r') = fused_add_rmsnorm(all_reduce(mm), w, residual)
                  outputs (quant(y, scale), r')
    Replacement:  s = reduce_scatter(mm)
                  (y_local, r') = fused_add_rmsnorm(s, w, residual[:len(s)])
                  outputs (all_gather(quant(y_local, scale)), r')
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example tensors: (residual, mm_1, rms_norm_weights, scale)."""
        square = [4, 4]
        mm_1 = torch.empty(square, device=self.device, dtype=self.dtype)
        residual = torch.empty(square, device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty(square, device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the full and first-output-only variants with ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, residual_out = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(mm_1)
            # The matcher works from the end of the graph, so `residual` is
            # still full size here. Add a temporary slice to this rank's
            # shard length. It becomes a no-op once the seqpar pattern with
            # the previous rmsnorm is replaced.
            local_residual = residual[0 : scattered.size(0), ...]
            normed, residual_out = self.rmsnorm_matcher(
                scattered, rms_norm_weights, local_residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            gathered = self._all_gather(quantized)
            # residual_out now has the sharded shape; that's fine because the
            # next node already slices it, which now becomes a no-op.
            return gathered, residual_out

        # The second pair matches uses where only the quantized output is
        # consumed (presumably where the updated residual is unused).
        variants = (
            (pattern, replacement),
            (get_first_out_wrapper(pattern), get_first_out_wrapper(replacement)),
        )
        for search_fn, replace_fn in variants:
            pm.register_replacement(
                search_fn, replace_fn, self.get_inputs(), pm.fwd_only, pm_pass
            )
322
323


324
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    Sequence-parallel rewrite of AllReduce + RMSNorm regions.

    The pass looks for an AllReduce whose result feeds an RMSNorm, optionally
    followed by static FP8 quantization. It rewrites each match as a
    ReduceScatter, a rank-local RMSNorm/quantization, and an AllGather:

        Input -> AllReduce -> RMSNorm -> Output
    becomes
        Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    The rewrite alone does not make the model faster. It exists so that later
    fusion passes (GEMM + ReduceScatter, AllGather + GEMM) can cut
    communication overhead and improve end-to-end performance.

    The rewrite splits the residual tensor across TP ranks, dividing its size.
    Because the pattern matcher starts at the end of the graph, each
    replacement inserts a slice. The slice temporarily trims the incoming,
    still full-size residual to the sharded size. Once every pattern has been
    matched those slices are no-ops, and a NoOpEliminationPass removes them.

    An older version of this pass needed no slicing. It matched only the
    custom rms_norm and fused_rms_norm_add ops, which did not complain about
    mismatched shapes during replacement. The current approach keeps the same
    assumption: correctness is only maintained if every rms_norm operation is
    split across ranks.

    For correctness, this approach is strictly better than before. Previously
    the graph was incorrect both semantically and shape-wise during the pass;
    now it is only semantically incorrect mid-pass. Both approaches restore a
    correct graph once all patterns have been matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Token threshold below which SP is skipped. It is precomputed on the
        # pass config during config init; None means SP is disabled.
        threshold = None
        if config.model_config is not None:
            threshold = config.compilation_config.pass_config.sp_min_token_num
            if threshold is not None:
                # Never require more tokens than a batch can hold.
                max_batched = config.scheduler_config.max_num_batched_tokens
                if max_batched is not None:
                    threshold = min(threshold, max_batched)
                logger.debug_once(
                    f"Sequence parallelism min token threshold: {threshold}",
                    scope="global",
                )
        self.min_token_num = threshold

        # Cleans up the temporary slices the replacements insert to work
        # around the residual shape change.
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # For each epsilon, the static-FP8 variants are registered before the
        # plain RMSNorm ones. This order is kept deliberately because it may
        # decide which pattern wins on overlapping matches.
        pattern_types = (
            FirstAllReduceRMSNormStaticFP8Pattern,
            MiddleAllReduceRMSNormStaticFP8Pattern,
            FirstAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
        )
        for epsilon in (1e-5, 1e-6):
            for pattern_type in pattern_types:
                sp_pattern = pattern_type(epsilon, self.model_dtype, self.device)
                sp_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Decide whether sequence parallelism should run for ``compile_range``.

        SP only pays off for larger batches, where the communication overhead
        is amortized. For small batches, splitting and gathering tensors across
        TP ranks costs more than it saves.

        Returns False (SP disabled) when:
        - piecewise compilation is used and the range is not a single size,
          or that size is not divisible by the TP world size;
        - min_token_num is None (SP disabled for this device/config);
        - the range starts below the minimum token threshold.
        """
        # Piecewise compilation (splitting ops without inductor graph
        # partition) needs concrete, TP-divisible sizes to split correctly.
        piecewise = (
            not self.compilation_config.use_inductor_graph_partition
            and self.compilation_config.splitting_ops
        )
        if piecewise:
            tp_size = get_tensor_model_parallel_world_size()
            if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
                return False

        # None covers non-CUDA platforms, unconfigured GPUs and small models.
        return (
            self.min_token_num is not None
            and compile_range.start >= self.min_token_num
        )

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Remove the residual slices that became no-ops after all rewrites.
        self.noop_cleanup(graph)