# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and The HuggingFace Team
# Type definitions in this module are adapted from the HuggingFace diffusers library:
#   diffusers/src/diffusers/models/_modeling_parallel.py
"""Sequence Parallelism configuration and plan type definitions.

This module provides:
1. SequenceParallelConfig: Configuration for SP (ulysses_degree, ring_degree)
2. SequenceParallelInput/Output: Type definitions for _sp_plan declarations
3. Validation utilities for _sp_plan

A _sp_plan is a dictionary that specifies how to shard/gather tensors at
different points in a model's forward pass. This allows automatic handling
of sequence parallelism without modifying the model's forward() method.

NOTE: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP)
in diffusers. We use "Sequence Parallelism" to align with vLLM-Omni terminology.

Example:
    class MyTransformer(nn.Module):
        _sp_plan = {
            # Split inputs before model forward
            "": {
                "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
                "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
            },
            # Split RoPE embeddings after pos_embed layer
            "pos_embed": {
                0: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
            },
            # Gather output after proj_out layer
            "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
        }
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch
    import torch.nn as nn


# =============================================================================
# Sequence Parallel Configuration
# =============================================================================


@dataclass
class SequenceParallelConfig:
    """Configuration for Sequence Parallelism using vLLM-Omni's parallel state.

    This class provides a unified interface for SP configuration that integrates
    with vLLM-Omni's existing SequenceParallelGroupCoordinator. Unlike diffusers'
    DeviceMesh-based approach (ContextParallelConfig), this uses the existing
    parallel state management.

    Note: This corresponds to `ContextParallelConfig` in diffusers library.

    Args:
        ulysses_degree: Number of devices for Ulysses (All-to-All) attention.
            Sequence is split across devices, with Q/K/V redistributed via
            All-to-All communication. Best for moderate sequences with good
            interconnect bandwidth.
        ring_degree: Number of devices for Ring attention. Sequence is split
            across devices, with K/V passed in a ring topology. Best for long
            sequences with limited memory/bandwidth.
        convert_to_fp32: Whether to convert output and LSE to float32 for
            numerical stability in ring attention.

    Note:
        ulysses_degree * ring_degree = sequence_parallel_size
        vLLM-Omni supports hybrid Ulysses-Ring attention (both > 1).
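
    Example:
        # A minimal sketch (illustrative degrees); the runtime must have
        # initialized an SP group of matching size (ulysses_degree * ring_degree).
        config = SequenceParallelConfig(ulysses_degree=2, ring_degree=2)
        config.sequence_parallel_size  # -> 4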
    """

    ulysses_degree: int = 1
    ring_degree: int = 1
    convert_to_fp32: bool = True

    # Internal state - populated by setup()
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None

    def __post_init__(self) -> None:
        if self.ulysses_degree < 1 or self.ring_degree < 1:
            raise ValueError("`ulysses_degree` and `ring_degree` must be >= 1.")

        if self.ulysses_degree == 1 and self.ring_degree == 1:
            raise ValueError(
                "At least one of `ulysses_degree` or `ring_degree` must be > 1 to use sequence parallelism."
            )

    @property
    def sequence_parallel_size(self) -> int:
        """Total sequence parallel world size."""
        return self.ulysses_degree * self.ring_degree

    def get_world_size(self) -> int:
        """Get the sequence parallel world size from parallel state.

        Returns:
            The world size for sequence parallelism.

        Raises:
            RuntimeError: If parallel state is not initialized.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size

        return get_sequence_parallel_world_size()

    def get_rank(self) -> int:
        """Get the current rank in the sequence parallel group.

        Returns:
            The rank within the sequence parallel group.

        Raises:
            RuntimeError: If parallel state is not initialized.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_rank

        return get_sequence_parallel_rank()

    def get_ulysses_world_size(self) -> int:
        """Get the Ulysses parallel world size.

        Returns:
            The world size for Ulysses (All-to-All) parallelism.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_ulysses_parallel_world_size

        return get_ulysses_parallel_world_size()

    def get_ulysses_rank(self) -> int:
        """Get the current rank in the Ulysses parallel group.

        Returns:
            The rank within the Ulysses parallel group.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_ulysses_parallel_rank

        return get_ulysses_parallel_rank()

    def get_ring_world_size(self) -> int:
        """Get the Ring parallel world size.

        Returns:
            The world size for Ring attention parallelism.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_ring_parallel_world_size

        return get_ring_parallel_world_size()

    def get_ring_rank(self) -> int:
        """Get the current rank in the Ring parallel group.

        Returns:
            The rank within the Ring parallel group.
        """
        from vllm_omni.diffusion.distributed.parallel_state import get_ring_parallel_rank

        return get_ring_parallel_rank()

    def setup(self, rank: int, world_size: int, device: torch.device) -> None:
        """Initialize the config with runtime parallel state.

        This is called automatically when sequence parallelism is enabled.

        Args:
            rank: The global rank of this process.
            world_size: Total world size.
            device: The device for this rank.
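
        Example:
            # Illustrative sketch only; the runtime normally calls this after
            # initializing the SP group. Assumes torch.distributed is set up
            # and `dist` is `torch.distributed`.
            config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
            config.setup(
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                device=torch.device("cuda", torch.cuda.current_device()),
            )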
        """
        self._rank = rank
        self._world_size = world_size
        self._device = device

        expected_sp_size = self.ulysses_degree * self.ring_degree
        actual_sp_size = self.get_world_size()

        if expected_sp_size != actual_sp_size:
            raise ValueError(
                f"Configuration mismatch: ulysses_degree ({self.ulysses_degree}) * "
                f"ring_degree ({self.ring_degree}) = {expected_sp_size}, but "
                f"actual sequence parallel world size is {actual_sp_size}."
            )

    def is_initialized(self) -> bool:
        """Check if the config has been initialized with runtime state.

        Returns:
            True if setup() has been called, False otherwise.
        """
        return self._rank is not None


# =============================================================================
# Sequence Parallel Plan Type Definitions
# =============================================================================


@dataclass(frozen=True)
class SequenceParallelInput:
    """Configuration for splitting an input tensor across sequence parallel ranks.

    This specifies how to shard a tensor in the pre-forward or post-forward hook
    of a layer. The tensor will be split along the specified dimension.

    Note: This corresponds to `ContextParallelInput` in diffusers library.

    Args:
        split_dim: The dimension along which to split the tensor.
        expected_dims: Expected number of dimensions. If provided, validates that
            the tensor has this many dimensions before splitting. If the tensor
            has a different number of dimensions, splitting is skipped with a warning.
        split_output: If True, split the output of the layer instead of the input.
            This is useful for layers whose outputs should be split after preprocessing
            (e.g., RoPE embeddings).
        auto_pad: If True, automatically pad the tensor if its size along split_dim
            is not divisible by world_size. Creates an attention mask to indicate
            valid vs padding positions. The mask is stored in ForwardContext.
            Note: Ring attention does not support attention masks, so auto_pad
            should only be used with Ulysses SP.

    Example:
        # Split hidden_states along sequence dimension (dim 1)
        SequenceParallelInput(split_dim=1, expected_dims=3)

        # Split RoPE output along sequence dimension (dim 0)
        SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True)

        # Split with auto-padding for variable-length sequences
        SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)
    """

    split_dim: int
    expected_dims: int | None = None
    split_output: bool = False
    auto_pad: bool = False

    def __repr__(self) -> str:
        return (
            f"SequenceParallelInput(split_dim={self.split_dim}, "
            f"expected_dims={self.expected_dims}, split_output={self.split_output}, "
            f"auto_pad={self.auto_pad})"
        )


@dataclass(frozen=True)
class SequenceParallelOutput:
    """Configuration for gathering an output tensor across sequence parallel ranks.

    This specifies how to gather a tensor in the post-forward hook of a layer.
    The tensor will be gathered along the specified dimension from all ranks.

    Note: This corresponds to `ContextParallelOutput` in diffusers library.

    Args:
        gather_dim: The dimension along which to gather the tensor.
        expected_dims: Expected number of dimensions. If provided, validates that
            the tensor has this many dimensions before gathering.

    Example:
        # Gather output along sequence dimension (dim 1)
        SequenceParallelOutput(gather_dim=1, expected_dims=3)
    """

    gather_dim: int
    expected_dims: int | None = None

    def __repr__(self) -> str:
        return f"SequenceParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"


@dataclass(frozen=True)
class SequenceParallelPartialInput:
    """Configuration for partially splitting a tensor (e.g., split image part, keep text part).

    This is designed for models like LongCat/Qwen where RoPE embeddings need special handling:
    - Text portion: kept full across all ranks (for joint attention)
    - Image portion: split across ranks

    The tensor is assumed to be concatenated as [text_part, image_part] along split_dim.

    Note: This is an extension beyond diffusers' standard ContextParallelInput,
    designed for vLLM-Omni's dual-stream attention models.

    Args:
        split_dim: The dimension along which to split the image portion.
        text_len_source: How to determine the text length:
            - str: Name of a forward argument whose leading dimension
              (shape[0]) gives the text length (e.g., "txt_ids")
            - int: Fixed text length value
        expected_dims: Expected number of dimensions for validation.
        split_output: If True, split the output instead of input.

    Example:
        # Split RoPE: text portion (from txt_ids.shape[0]) kept full, image portion split
        SequenceParallelPartialInput(
            split_dim=0,
            text_len_source="txt_ids",  # Get text length from txt_ids.shape[0]
            expected_dims=2,
            split_output=True,
        )

        # Or with fixed text length
        SequenceParallelPartialInput(
            split_dim=0,
            text_len_source=512,  # Fixed text length
            expected_dims=2,
            split_output=True,
        )
    """

    split_dim: int
    text_len_source: str | int
    expected_dims: int | None = None
    split_output: bool = False

    def __repr__(self) -> str:
        return (
            f"SequenceParallelPartialInput(split_dim={self.split_dim}, "
            f"text_len_source={self.text_len_source!r}, expected_dims={self.expected_dims}, "
            f"split_output={self.split_output})"
        )


# =============================================================================
# Type Aliases for _sp_plan Structure
# =============================================================================

# Any input config type
AnySequenceParallelInput = SequenceParallelInput | SequenceParallelPartialInput

# Input specification: maps parameter names (str) or output indices (int) to split configs
SequenceParallelInputType = dict[
    str | int,
    AnySequenceParallelInput | list[AnySequenceParallelInput] | tuple[AnySequenceParallelInput, ...],
]

# Output specification: single or multiple gather configs
SequenceParallelOutputType = SequenceParallelOutput | list[SequenceParallelOutput] | tuple[SequenceParallelOutput, ...]

# Full model plan: maps module names to input/output specifications
# - Key "" refers to the model itself (root level)
# - Key "module_name" refers to a submodule
# - Key "module_name.*" refers to all children of a ModuleList
#
# Example of a complete _sp_plan:
#
#     _sp_plan = {
#         # Root level: split model inputs before any submodule
#         "": {
#             "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
#         },
#         # Submodule: split outputs of pos_embed (RoPE) layer
#         "pos_embed": {
#             0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # cos
#             1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # sin
#         },
#         # Submodule: gather outputs of proj_out layer
#         "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
#     }
#
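# The ".*" wildcard key applies the same spec to every child of a ModuleList.
# For example (hypothetical submodule name):
#
#     _sp_plan = {
#         "transformer_blocks.*": {
#             "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
#         },
#     }
#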
SequenceParallelModelPlan = dict[str, SequenceParallelInputType | SequenceParallelOutputType]


# =============================================================================
# Validation Utilities
# =============================================================================


def _is_valid_input_config(value: object) -> bool:
    """Check if a value is a valid input configuration type."""
    return isinstance(value, (SequenceParallelInput, SequenceParallelPartialInput))


def _is_valid_input_config_list(value: object) -> bool:
    """Check if a value is a list/tuple of valid input configurations."""
    if not isinstance(value, (list, tuple)):
        return False
    return all(_is_valid_input_config(x) for x in value)


def validate_sp_plan(plan: SequenceParallelModelPlan) -> None:
    """Validate a _sp_plan dictionary for correctness.

    Args:
        plan: The _sp_plan dictionary to validate.

    Raises:
        ValueError: If the plan is invalid.
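
    Example:
        # Typically called on a model's class-level plan before SP hooks are
        # installed; raises ValueError for malformed entries.
        validate_sp_plan({
            "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
            "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
        })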
    """
    if not isinstance(plan, dict):
        raise ValueError(f"_sp_plan must be a dict, got {type(plan).__name__}")

    for module_id, module_plan in plan.items():
        if not isinstance(module_id, str):
            raise ValueError(f"_sp_plan keys must be strings, got {type(module_id).__name__}")

        # Check if it's an output specification (SequenceParallelOutput or list/tuple thereof)
        if isinstance(module_plan, SequenceParallelOutput):
            continue
        if isinstance(module_plan, (list, tuple)):
            if all(isinstance(x, SequenceParallelOutput) for x in module_plan):
                continue
            if _is_valid_input_config_list(module_plan):
                # A list of input configs, applied positionally when the module's output is a tuple
                continue

        # Otherwise, should be an input specification dict
        if isinstance(module_plan, dict):
            for key, value in module_plan.items():
                if not isinstance(key, (str, int)):
                    raise ValueError(
                        f"Input spec keys must be str or int, got {type(key).__name__} for module '{module_id}'"
                    )
                if isinstance(key, int) and not _is_valid_input_config(value):
                    raise ValueError(
                        f"Integer keys (output indices) must map to SequenceParallelInput/PartialInput, "
                        f"got {type(value).__name__} for module '{module_id}'[{key}]"
                    )
                if _is_valid_input_config(value):
                    if isinstance(key, int) and not value.split_output:
                        raise ValueError(
                            f"Integer keys (output indices) require split_output=True, "
                            f"got split_output=False for module '{module_id}'[{key}]"
                        )
                elif _is_valid_input_config_list(value):
                    pass  # Valid list of input configs
                else:
                    raise ValueError(
                        f"Input spec values must be SequenceParallelInput/PartialInput or list thereof, "
                        f"got {type(value).__name__} for module '{module_id}'['{key}']"
                    )
        else:
            raise ValueError(
                f"_sp_plan values must be dict (input spec) or SequenceParallelOutput, "
                f"got {type(module_plan).__name__} for module '{module_id}'"
            )


def get_sp_plan_from_model(model: nn.Module) -> SequenceParallelModelPlan | None:
    """Get the _sp_plan from a model if it exists.

    Args:
        model: The model to get the plan from.

    Returns:
        The _sp_plan dictionary, or None if not defined.
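
    Example:
        # `transformer` stands in for any nn.Module that may declare _sp_plan.
        plan = get_sp_plan_from_model(transformer)
        if plan is not None:
            # e.g. install shard/gather hooks according to the plan
            ...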
    """
    plan = getattr(model, "_sp_plan", None)
    if plan is not None:
        validate_sp_plan(plan)
    return plan