sequence_parallelism.py 17.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
6
from collections.abc import Callable, Sequence
from typing import Any
7

8
9
10
11
12
13
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
14
from vllm.config.utils import Range
15
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
16
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
17
from vllm.logger import init_logger
18
19
20
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
21
from vllm.platforms import current_platform
22

23
24
25
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
26
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
27
28
29

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Min hidden size per device capability for sequence parallelism.
# Keys are looked up with `capability.to_int()` in
# get_sequence_parallelism_threshold(); a capability without an entry here
# makes that function return None.
# Only apply sequence parallelism for models with hidden_size >= threshold.
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    90: 8192,  # H100: only for models with hidden_size >= 8192
}

# Min size per GPU (in MiB) per device capability for sequence parallelism.
# Keyed the same way as SP_MIN_HIDDEN_SIZE.
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism.
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    90: 8,  # 8MB per GPU for H100
}


def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Compute the smallest token count at which sequence parallelism is applied.

    The result is None (sequence parallelism not applied) when any of the
    following holds:
    - the current platform is not CUDA,
    - the device capability is unavailable,
    - the capability has no entry in SP_MIN_HIDDEN_SIZE or
      SP_MIN_PER_GPU_SIZE_MB,
    - hidden_size is below SP_MIN_HIDDEN_SIZE[device_capability].

    Otherwise the threshold is derived from the per-GPU size budget:

        min_token_num = (min_per_gpu_size_mb * MiB * tp_size) //
                        (hidden_size * element_size)

    Args:
        hidden_size: Hidden size of the model.
        tp_size: Tensor-parallel world size; scales the total size budget.
        element_size: Size in bytes of one activation element.

    Returns:
        The minimum token count as an int, or None as described above.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    if (cap := current_platform.get_device_capability()) is None:
        return None
    cap_key = cap.to_int()

    # Both tables must have an entry for this device, and the model must be
    # large enough, before a threshold is produced.
    hidden_floor = SP_MIN_HIDDEN_SIZE.get(cap_key)
    per_gpu_mb = SP_MIN_PER_GPU_SIZE_MB.get(cap_key)
    if hidden_floor is None or per_gpu_mb is None or hidden_size < hidden_floor:
        return None

    bytes_per_token = hidden_size * element_size
    total_min_bytes = per_gpu_mb * (1024 * 1024) * tp_size
    return int(total_min_bytes // bytes_per_token)

87

88
89
90
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """Wrap ``fn`` so that only the first element of its result is returned.

    Positional arguments are forwarded to ``fn`` unchanged. The returned
    callable carries ``fn``'s metadata (name, docstring, ``__wrapped__``)
    through ``functools.wraps``.
    """

    def first_output(*args: Any) -> torch.Tensor:
        outputs = fn(*args)
        return outputs[0]

    return functools.wraps(fn)(first_output)
96

97

98
class _SequenceParallelPatternHelper:
    """Shared state and collective-op helpers for sequence-parallel patterns.

    Stores the epsilon, dtype and device handed in by the caller and caches
    the tensor-parallel group and world size that the collective helpers use.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon, self.dtype, self.device = epsilon, dtype, device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the tensor-model-parallel all-reduce to ``x``."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Call the vLLM reduce-scatter op on ``x`` along dim 0."""
        op = torch.ops.vllm.reduce_scatter.default
        return op(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """Call the vLLM all-gather op on ``x`` along dim 0."""
        op = torch.ops.vllm.all_gather.default
        return op(
            x,
            dim=0,
            world_size=self.tp_size,
            group_name=self.tp_group.unique_name,
        )
125
126
127


class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> MatcherRMSNorm`` (no residual input).

    The replacement computes ``reduce_scatter -> MatcherRMSNorm -> all_gather``
    and returns the reduce-scattered tensor where the pattern returned the
    all-reduced one.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example ``[input, arg3_1]`` tensors used for registration."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        return [
            torch.empty([1, 8, 4], **tensor_opts),  # input
            torch.empty([4], **tensor_opts),  # arg3_1, 2nd arg of the matcher
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, arg3_1)
            return normed, reduced

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, arg3_1)
            return self._all_gather(normed), scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
161
162


163
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> MatcherFusedAddRMSNorm`` (with residual input).

    The replacement reduce-scatters the input, runs the matcher on the
    scattered tensor with a sliced residual, and all-gathers only the first
    matcher output; the second output is returned as produced.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[residual, mm_1, rms_norm_weights]`` tensors."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        # Three distinct [4, 4] tensors, in the parameter order of
        # pattern()/replacement() in register().
        return [torch.empty([4, 4], **tensor_opts) for _ in range(3)]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output pair, then its first-output-only variant."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, residual_out = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            return normed, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual_slice = residual[0 : scattered.size(0), ...]
            normed, residual_out = self.rmsnorm_matcher(
                scattered, rms_norm_weights, residual_slice
            )
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return self._all_gather(normed), residual_out

        # The second pair keeps only the first output of each function;
        # presumably for graphs where the residual output is unused -
        # TODO confirm.
        variants = (
            (pattern, replacement),
            (get_first_out_wrapper(pattern), get_first_out_wrapper(replacement)),
        )
        for search_fn, replace_fn in variants:
            pm.register_replacement(
                search_fn, replace_fn, self.get_inputs(), pm.fwd_only, pm_pass
            )
216
217
218
219
220
221


# FP8 dtype reported by the current platform, resolved once at import time.
# NOTE(review): not referenced in the visible part of this file - confirm
# whether other modules import it before removing.
FP8_DTYPE = current_platform.fp8_dtype()


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> MatcherRMSNorm -> MatcherQuantFP8`` (no residual).

    The replacement computes ``reduce_scatter -> MatcherRMSNorm ->
    MatcherQuantFP8 -> all_gather`` and returns the reduce-scattered tensor
    where the pattern returned the all-reduced one.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example ``[input, weight, scale]`` tensors."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        return [
            torch.zeros([1, 8, 4], **tensor_opts),  # input
            torch.empty([4], **tensor_opts),  # weight
            torch.tensor(1.0, device=self.device, dtype=torch.float32),  # scale
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            # Only the first output of the quant matcher is used.
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return self._all_gather(quantized), scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
264
265
266


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """Rewrite ``all_reduce -> MatcherFusedAddRMSNorm -> MatcherQuantFP8``.

    The replacement reduce-scatters the input, runs both matchers on the
    scattered tensor with a sliced residual, and all-gathers only the
    quantized output; the residual output is returned as produced.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[residual, mm_1, rms_norm_weights, scale]`` tensors."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        residual, mm_1, rms_norm_weights = (
            torch.empty([4, 4], **tensor_opts) for _ in range(3)
        )
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output pair, then its first-output-only variant."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, residual_out = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            # Only the first output of the quant matcher is used.
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual_slice = residual[0 : scattered.size(0), ...]
            normed, residual_out = self.rmsnorm_matcher(
                scattered, rms_norm_weights, residual_slice
            )
            quantized, _ = self.quant_matcher(normed, scale)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return self._all_gather(quantized), residual_out

        # The second pair keeps only the first output of each function;
        # presumably for graphs where the residual output is unused -
        # TODO confirm.
        variants = (
            (pattern, replacement),
            (get_first_out_wrapper(pattern), get_first_out_wrapper(replacement)),
        )
        for search_fn, replace_fn in variants:
            pm.register_replacement(
                search_fn, replace_fn, self.get_inputs(), pm.fwd_only, pm_pass
            )
326
327


328
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    Enable sequence parallelism by rewriting AllReduce + RMSNorm patterns.

    The pass looks for an AllReduce followed by an RMSNorm (optionally followed
    by quantization) and replaces it with a ReduceScatter, a local
    RMSNorm/quantization, and an AllGather:

        Input -> AllReduce -> RMSNorm -> Output
    becomes
        Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    The pass does not directly improve performance by itself. It prepares the
    graph for later fusion passes such as GEMM + ReduceScatter and
    AllGather + GEMM, which can significantly reduce communication overhead
    and improve overall model performance.

    The pass splits the residual tensor across TP ranks and hence divides its
    size. Because the pattern matcher starts at the end of the graph, the
    replacement contains a slice that temporarily conforms the input residual
    to the correct size. After all patterns have been matched, a
    NoOpEliminationPass cleans up what have now become no-op slices.

    An older version of the pass did not need this, as it operated only on the
    custom rms_norm and fused_rms_norm_add ops, which did not complain about
    mismatched shapes during replacement. This approach therefore keeps the
    same assumption: correctness is only maintained if all rms_norm operations
    are split across ranks.

    Correctness-wise, this approach is strictly better than before: previously
    the graph was incorrect both semantically and shape-wise during the pass,
    whereas now it is only semantically incorrect during the pass. Both
    approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Read min_token_num from config (calculated during config init).
        # It stays None when there is no model config or no configured value.
        threshold = None
        if config.model_config is not None:
            threshold = config.compilation_config.pass_config.sp_min_token_num
            if threshold is not None:
                # Take the min to avoid exceeding max_num_batched_tokens
                token_cap = config.scheduler_config.max_num_batched_tokens
                if token_cap is not None:
                    threshold = min(threshold, token_cap)
        self.min_token_num = threshold
        if threshold is not None:
            logger.debug_once(
                f"Sequence parallelism min token threshold: {self.min_token_num}",
                scope="global",
            )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # Registration order per epsilon: the RMSNorm + static FP8
        # quantization patterns first, then the plain RMSNorm patterns.
        pattern_classes = (
            FirstAllReduceRMSNormStaticFP8Pattern,
            MiddleAllReduceRMSNormStaticFP8Pattern,
            FirstAllReduceRMSNormPattern,
            MiddleAllReduceRMSNormPattern,
        )
        for epsilon in [1e-5, 1e-6]:
            for pattern_cls in pattern_classes:
                pattern_cls(epsilon, self.model_dtype, self.device).register(
                    self.patterns
                )

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Decide whether sequence parallelism applies to ``compile_range``.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - Using piecewise compilation with non-concrete or TP-indivisible sizes
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold
        """
        cfg = self.compilation_config

        # For piecewise compilation (not using inductor graph partition),
        # we need concrete sizes that are divisible by TP for correct splitting
        if not cfg.use_inductor_graph_partition and cfg.splitting_ops:
            tp_size = get_tensor_model_parallel_world_size()
            splittable = (
                compile_range.is_single_size() and compile_range.end % tp_size == 0
            )
            if not splittable:
                return False

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
        threshold = self.min_token_num
        if threshold is None:
            return False

        # Only apply SP when batch size meets the minimum threshold
        return compile_range.start >= threshold

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph``, then run the no-op cleanup."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

        # Clean up reshape nodes
        self.noop_cleanup(graph)