# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project and The HuggingFace Team
"""Sequence Parallelism sharding utilities.

This module provides low-level sharding and gathering functions for
Sequence Parallelism. These can be used directly in model forward methods
for semi-intrusive SP support, or internally by the SP hooks.
"""

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.distributed.parallel_state import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)

logger = init_logger(__name__)


def sp_shard(
    tensor: torch.Tensor,
    dim: int,
    validate: bool = True,
) -> torch.Tensor:
    """Return this rank's shard of ``tensor`` split along ``dim``.

    The tensor is divided into ``world_size`` equal chunks along ``dim``;
    the chunk at index ``rank`` is returned. When the sequence-parallel
    world size is 1, the input is returned untouched.

    Args:
        tensor: The full (unsharded) tensor.
        dim: Dimension to split along.
        validate: When True, raise if ``tensor.size(dim)`` is not evenly
            divisible by the world size.

    Returns:
        The local shard for the current sequence-parallel rank.

    Raises:
        ValueError: If ``validate`` is True and the size along ``dim`` is
            not divisible by the world size.

    Example:
        # In model forward:
        hidden_states = sp_shard(hidden_states, dim=1)
    """
    world_size = get_sequence_parallel_world_size()
    if world_size == 1:
        # Nothing to split when SP is disabled.
        return tensor

    dim_size = tensor.size(dim)
    if validate and dim_size % world_size:
        raise ValueError(
            f"Tensor size along dim {dim} ({dim_size}) must be divisible by "
            f"world_size ({world_size}) for sequence parallel sharding."
        )

    chunks = tensor.chunk(world_size, dim=dim)
    return chunks[get_sequence_parallel_rank()]


def sp_gather(
    tensor: torch.Tensor,
    dim: int,
    validate: bool = True,
) -> torch.Tensor:
    """Concatenate the local shards from all sequence-parallel ranks.

    Inverse of :func:`sp_shard`: every rank contributes its shard and the
    shards are concatenated along ``dim``. When the world size is 1 the
    input is returned unchanged.

    Args:
        tensor: This rank's local shard.
        dim: Dimension to concatenate along.
        validate: Reserved for consistency checks; currently unused.

    Returns:
        The full tensor assembled from all ranks.

    Example:
        # At end of model forward:
        output = sp_gather(output, dim=1)
    """
    if get_sequence_parallel_world_size() == 1:
        # Single-rank case: the "shard" already is the full tensor.
        return tensor

    return get_sp_group().all_gather(tensor, dim=dim)


def sp_shard_with_padding(
    tensor: torch.Tensor,
    dim: int,
    pad_value: float = 0.0,
) -> tuple[torch.Tensor, int]:
    """Shard a tensor along ``dim``, padding first if needed for divisibility.

    Useful for variable-length sequences whose size along ``dim`` is not a
    multiple of the sequence-parallel world size: the tensor is extended
    with ``pad_value`` entries until it divides evenly, then sharded.

    Args:
        tensor: The tensor to shard.
        dim: Dimension to split along.
        pad_value: Fill value for the padding region.

    Returns:
        Tuple ``(shard, pad_size)`` where ``pad_size`` is the number of
        padded positions appended along ``dim`` before sharding (0 when no
        padding was required).

    Example:
        sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1)
        # ... process ...
        output = sp_gather(output, dim=1)
        if pad_size > 0:
            # Trim the padded positions back off along the same dim.
            output = output.narrow(dim, 0, output.size(dim) - pad_size)
    """
    world_size = get_sequence_parallel_world_size()
    if world_size == 1:
        return tensor, 0

    # (-n) % w is 0 when n divides evenly, else w - (n % w).
    pad_size = (-tensor.size(dim)) % world_size
    if pad_size == 0:
        return sp_shard(tensor, dim, validate=False), 0

    pad_shape = list(tensor.shape)
    pad_shape[dim] = pad_size
    filler = torch.full(
        pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device
    )
    padded = torch.cat((tensor, filler), dim=dim)

    return sp_shard(padded, dim, validate=False), pad_size


# NOTE: This class is a vLLM-Omni extension for
# debugging intrusive SP implementations.
# Purpose:
# - Help developers detect bugs when implementing intrusive SP
# - Verify that every sharded tensor is properly gathered
# - Warn about common mistakes (double shard, gather without shard)
#
# When to use:
# - During development/debugging of intrusive SP code
# - In tests to verify shard/gather correctness
@dataclass
class ShardingValidator:
    """Validator for tracking and verifying sharding operations.

    This class helps ensure that sharding and gathering operations are
    correctly paired in model forward passes. It tracks which tensors
    have been sharded and verifies that they are properly gathered.

    Usage:
        validator = ShardingValidator()
        with validator.track():
            hidden_states = validator.shard(hidden_states, "hidden_states", dim=1)
            # ... model computation ...
            output = validator.gather(output, "hidden_states", dim=1)
        validator.validate()  # Raises if any shard was not gathered

    Attributes:
        _sharded: Set of tensor names that have been sharded.
        _gathered: Set of tensor names that have been gathered.
        _enabled: Whether tracking is currently enabled.
    """

    _sharded: set[str] = field(default_factory=set)
    _gathered: set[str] = field(default_factory=set)
    _enabled: bool = False

    def reset(self) -> None:
        """Reset the validator state for a new forward pass."""
        self._sharded.clear()
        self._gathered.clear()

    @contextmanager
    def track(self):
        """Context manager to enable tracking for a forward pass.

        Clears any prior state on entry and always disables tracking on
        exit, even if the wrapped forward pass raises.
        """
        self._enabled = True
        self.reset()
        try:
            yield
        finally:
            self._enabled = False

    def shard(
        self,
        tensor: torch.Tensor,
        name: str,
        dim: int,
        validate_divisible: bool = True,
    ) -> torch.Tensor:
        """Shard a tensor and track the operation.

        Args:
            tensor: The tensor to shard.
            name: A name to identify this tensor for validation.
            dim: The dimension along which to split.
            validate_divisible: If True, validate divisibility.

        Returns:
            The sharded tensor.
        """
        if self._enabled:
            if name in self._sharded:
                # Double-shard is almost always a bug: the second call
                # would split an already-sharded tensor again.
                # Lazy %-formatting so the message is only built if emitted.
                logger.warning("Tensor '%s' sharded multiple times", name)
            self._sharded.add(name)

        return sp_shard(tensor, dim, validate=validate_divisible)

    def gather(
        self,
        tensor: torch.Tensor,
        name: str,
        dim: int,
    ) -> torch.Tensor:
        """Gather a tensor and track the operation.

        Args:
            tensor: The local shard to gather.
            name: The name used when sharding (for validation).
            dim: The dimension along which to gather.

        Returns:
            The gathered tensor.
        """
        if self._enabled:
            if name not in self._sharded:
                # Gathering an unsharded tensor duplicates data across ranks.
                logger.warning("Tensor '%s' gathered without being sharded", name)
            self._gathered.add(name)

        return sp_gather(tensor, dim)

    def validate(self) -> None:
        """Validate that all sharded tensors were gathered.

        Raises:
            ValueError: If any sharded tensor was not gathered.
        """
        unmatched = self._sharded - self._gathered
        if unmatched:
            raise ValueError(
                f"The following tensors were sharded but not gathered: {unmatched}. "
                f"This may indicate a bug in the model's SP implementation."
            )


# Module-level singleton so all call sites can share one validator
# without threading it through every signature.
_global_validator = ShardingValidator()


def get_sharding_validator() -> ShardingValidator:
    """Return the process-wide shared :class:`ShardingValidator`.

    Returns:
        The global ShardingValidator singleton.
    """
    return _global_validator