sequence_parallelism.py 16.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
6
from collections.abc import Callable, Sequence
from typing import Any
7

8
9
10
11
12
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass

13
import vllm.ir.ops
14
from vllm.config import VllmConfig
15
from vllm.config.utils import Range
16
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
17
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
18
from vllm.logger import init_logger
19
20
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
22

23
24
25
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
26
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
27
28
29

logger = init_logger(__name__)

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Per-device-capability lower bound on the model hidden size.
# Sequence parallelism is only applied when hidden_size >= this threshold;
# capabilities that are not listed here have no configured threshold.
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    # H100: only for models with hidden_size >= 8192
    90: 8192,
}

# Per-device-capability lower bound on the per-GPU size, in MiB.
# The total minimum size is min_per_gpu_size * tp_size, so the threshold
# scales with the tensor-parallel world size.
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    # H100: 8MB per GPU
    90: 8,
}


def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Compute the minimum number of tokens for which sequence parallelism
    should be applied, or None when it should not be applied at all.

    None is returned when any of the following holds:
    - the current platform is not CUDA,
    - the device capability cannot be determined,
    - the device capability has no entry in SP_MIN_HIDDEN_SIZE or
      SP_MIN_PER_GPU_SIZE_MB,
    - hidden_size is below SP_MIN_HIDDEN_SIZE[device_capability].

    Otherwise the threshold is:
        (min_per_gpu_size_mb * MiB * tp_size) // (hidden_size * element_size)
    """
    # Imported lazily inside the function rather than at module import time.
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    if (capability := current_platform.get_device_capability()) is None:
        return None
    cap_key = capability.to_int()

    # Both tables must have an entry for this device, and the model must be
    # at least as wide as the configured minimum hidden size.
    hidden_floor = SP_MIN_HIDDEN_SIZE.get(cap_key)
    per_gpu_mb = SP_MIN_PER_GPU_SIZE_MB.get(cap_key)
    if hidden_floor is None or per_gpu_mb is None or hidden_size < hidden_floor:
        return None

    bytes_per_token = hidden_size * element_size
    total_min_bytes = per_gpu_mb * (1024 * 1024) * tp_size
    return int(total_min_bytes // bytes_per_token)

87

88
89
90
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """
    Wrap ``fn`` so that only the first element of its result is returned.

    The wrapper forwards its positional arguments to ``fn`` unchanged and
    indexes the returned sequence with ``[0]``. ``functools.wraps`` copies
    the metadata of ``fn`` (name, docstring, ``__wrapped__``) onto the wrapper.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any) -> torch.Tensor:
        return fn(*args)[0]

    return wrapper
96

97

98
class _SequenceParallelPatternHelper:
99
100
    """Helper for sequence parallelism patterns."""

101
102
103
104
    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
105
106
        device: str | None,
    ) -> None:
107
108
109
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
110
111
112
113
114
115
116
117
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.reduce_scatter.default(
118
119
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )
120
121
122

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.vllm.all_gather.default(
123
124
            x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
        )
125

126
127
128
129
130
131
    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)

132
133

class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites AllReduce -> RMSNorm (no residual input) into
    ReduceScatter -> RMSNorm -> AllGather.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors passed to ``pm.register_replacement``."""
        # input, weight
        return [self.empty([1, 8, 4]), self.empty([4])]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(input)
            rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)

            return rmsnorm, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(input)

            # RMSNorm now runs on the scattered tensor; only the normalized
            # output is gathered back, while the second output stays scattered.
            rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
            all_gather = self._all_gather(rmsnorm)
            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
164
165


166
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites AllReduce -> fused-add RMSNorm (with a residual input) into
    ReduceScatter -> fused-add RMSNorm -> AllGather.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors passed to ``pm.register_replacement``."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        # Order matches the parameter order of pattern()/replacement().
        return [
            residual,
            mm_1,
            rms_norm_weights,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register the pattern/replacement pair on ``pm_pass``.

        Two registrations are made: one returning both matcher outputs and
        one, via ``get_first_out_wrapper``, returning only the first output.
        """

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
            return rmsnorm[0], rmsnorm[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
            all_gather = self._all_gather(rmsnorm[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, rmsnorm[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Same rewrite, but for graphs where only the first output is used.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
219
220
221


class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites AllReduce -> RMSNorm -> static FP8 quant (no residual input)
    into ReduceScatter -> RMSNorm -> static FP8 quant -> AllGather.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors passed to ``pm.register_replacement``."""
        # input, weight, scale
        return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(input)
            rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
            quant, _ = self.quant_matcher(rms, scale)
            return quant, all_reduce

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduce_scatter = self._reduce_scatter(input)
            rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
            # Quantize before gathering, so the all-gather moves the
            # quantized tensor rather than the RMSNorm output.
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)

            return all_gather, reduce_scatter

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
261
262
263


class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites AllReduce -> fused-add RMSNorm -> static FP8 quant (with a
    residual input) into
    ReduceScatter -> fused-add RMSNorm -> static FP8 quant -> AllGather.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors passed to ``pm.register_replacement``."""
        mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        # Order matches the parameter order of pattern()/replacement().
        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """
        Register the pattern/replacement pair on ``pm_pass``.

        Two registrations are made: one returning both outputs and one,
        via ``get_first_out_wrapper``, returning only the first output.
        """

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_reduce = self._all_reduce(mm_1)
            rms, residual_out = self.rmsnorm_matcher(
                all_reduce, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            return quant, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            reduce_scatter = self._reduce_scatter(mm_1)
            residual = residual[0 : reduce_scatter.size(0), ...]
            rms, residual_out = self.rmsnorm_matcher(
                reduce_scatter, rms_norm_weights, residual
            )
            quant, _ = self.quant_matcher(rms, scale)
            all_gather = self._all_gather(quant)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return all_gather, residual_out

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same rewrite, but for graphs where only the first output is used.
        pm.register_replacement(
            get_first_out_wrapper(pattern),
            get_first_out_wrapper(replacement),
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
323
324


325
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.

    This pass is only supported when compiling the whole graph (fullgraph
    mode, i.e. using Inductor graph partition or empty splitting_ops).
    Piecewise compilation is not supported because the residual tensor
    gets split across TP ranks, causing size mismatches at subgraph
    boundaries.

    This pass splits up the residual tensor across TP ranks and hence
    divides its size. Because the pattern matcher starts at the end of
    the graph, the replacement contains a slice that temporarily conforms
    the input residual to the correct size. After all patterns have been
    matched, we use a NoOpEliminationPass to clean up what have now
    become no-op slices.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Read the token threshold from ``config`` and register all
        sequence-parallelism patterns for epsilon values 1e-5 and 1e-6.
        """
        super().__init__(config)

        # Get min_token_num threshold
        # Read min_token_num from config (calculated during config init)
        self.min_token_num = None
        if config.model_config is not None:
            pass_config = config.compilation_config.pass_config
            self.min_token_num = pass_config.sp_min_token_num

            if self.min_token_num is not None:
                # Take the min to avoid exceeding max_num_batched_tokens
                max_batched = config.scheduler_config.max_num_batched_tokens
                if max_batched is not None:
                    self.min_token_num = min(self.min_token_num, max_batched)
                logger.debug_once(
                    f"Sequence parallelism min token threshold: {self.min_token_num}",
                    scope="global",
                )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Determines if sequence parallelism should be applied for the given
        compile range.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold

        Asserts that the compilation config uses Inductor graph partition
        or has empty splitting_ops (full-graph compilation).
        """
        assert (
            self.compilation_config.use_inductor_graph_partition
            or not self.compilation_config.splitting_ops
        ), "SequenceParallelismPass requires full-graph compilation"

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
        if self.min_token_num is None:
            return False

        # Only apply SP when batch size meets the minimum threshold
        return compile_range.start >= self.min_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph``, then run no-op cleanup."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)