"vllm/vscode:/vscode.git/clone" did not exist on "88d7bdbd2337917fbdd65bdfe33e6af7ecdb45cd"
forward_context.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
9

10
import torch
11
import torch.distributed as dist
12

13
import vllm.envs as envs
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
18

19
20
21
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

22
23
24
25
logger = init_logger(__name__)

# Per-batch-size forward-latency tracking, controlled by
# VLLM_LOG_BATCHSIZE_INTERVAL: a negative value disables tracking; otherwise
# it is the minimum number of seconds between two stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() timestamp of the most recent stats log line.
last_logging_time: float = 0
# time.perf_counter() timestamp at which the current tracked forward began.
forward_start_time: float = 0
# batch size -> list of observed forward latencies, in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
29
30


31
32
33
34
35
36
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """

    # Number of tokens in the (padded) batch.
    num_tokens: int
    # Whether the batch is a uniform decode batch. False can also be used for
    # a uniform decode batch to dispatch to the cudagraph supporting
    # non-uniform batches.
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of current batch descriptor.

        Only ``uniform_decode`` is changed (to False); every other field is
        carried over unchanged, including any fields added in the future.
        """
        return self._replace(uniform_decode=False)


53
54
55
56
57
58
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
59
60
61
62
63

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


64
65
66
67
68
69
70
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Token count each (DP, SP) rank processes in chunk ``chunk_idx``.

    The per-rank counts come from ``_compute_sp_num_tokens``. Chunk
    ``chunk_idx`` covers tokens ``[chunk_idx * max_num_tokens,
    (chunk_idx + 1) * max_num_tokens)`` of every rank.

    Args:
        num_tokens_across_dp_cpu: token count per DP rank.
        sequence_parallel_size: number of SP ranks per DP rank.
        max_num_tokens: maximum tokens per rank in a single chunk.
        chunk_idx: index of the chunk to size.

    Returns:
        One entry per SP rank. A rank with no tokens left in this chunk still
        reports 1, so that all ranks stay in lockstep.
    """
    # Take into account sharding if MoE activation is sequence parallel.
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
    chunk_start = max_num_tokens * chunk_idx
    # max(1, ...) ensures lockstep even if a rank is already done.
    return [max(1, min(max_num_tokens, n - chunk_start)) for n in sp_tokens]


82
83
@dataclass
class DPMetadata:
    """Token-count bookkeeping for data-parallel (DP) execution.

    Records how many tokens every DP rank processes in the current forward
    pass. Instances are normally built with :meth:`make`, which gathers the
    counts from all DP ranks when the caller does not supply them.
    """

    # Largest entry of ``num_tokens_across_dp_cpu`` (0-d tensor; see ``make``).
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of each DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes (or
    # sp_local_sizes) context managers; it is None outside of them.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(
        num_tokens: int, dp_size: int, dp_rank: int
    ) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.

        Each rank writes its own ``num_tokens`` at index ``dp_rank`` of a
        zero vector; the all-reduce (default op: sum) then gives every rank
        all counts. This is a collective: every DP rank must call it.
        """
        from vllm.distributed.parallel_state import get_dp_group

        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from GPU to CPU will introduce a GPU sync
        # point that could adversely affect performance of vllm with async
        # scheduling. This environment variable exists to quickly disable
        # this optimization if we run into this case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to syncronize DP padding between ranks."
            )
            device = "cpu"
            group = get_dp_group().cpu_group
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(
            num_tokens_across_dp, device=device, dtype=torch.int32
        )
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs will be distributed w.r.t both
        DP and TP rank. Each DP rank's count is split over ``sp_size`` SP
        ranks (rounding up) before the running sum is taken. When
        ``sp_size == 1``, this is just the cumulative num tokens across DP.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

    @staticmethod
    def should_ubatch_across_dp(
        should_ubatch: bool,
        orig_num_tokens_per_ubatch: int,
        padded_num_tokens_per_ubatch: int,
        dp_size: int,
        dp_rank: int,
    ) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it will "abort", meaning that no
        padding information will be returned to the caller: it returns
        (False, None).

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens.

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False
        ]
        """

        device = current_platform.device_type
        # Row 0: original per-ubatch token counts, row 1: padded counts,
        # row 2: 1 if the rank wants to microbatch. Each rank fills only its
        # own column; the all-reduce (sum) then fills in every column.
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group

        dist.all_reduce(tensor, group=get_dp_group().device_group)

        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug(
                "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
            )
            return False, None
        return result, padded_num_tokens_tensor.cpu()

    @staticmethod
    def make(
        parallel_config: "ParallelConfig",
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None,
    ) -> "DPMetadata":
        """Build the DPMetadata for the current forward pass.

        Args:
            parallel_config: must have ``data_parallel_size > 1``.
            attn_metadata: v0 attention metadata (with ``num_prefill_tokens``
                and ``num_decode_tokens``), or anything else including None.
            num_tokens: this rank's token count; used when ``attn_metadata``
                is not v0 metadata.
            num_tokens_across_dp_cpu: precomputed per-rank counts. When None,
                they are gathered with an all-reduce across DP ranks.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = (
                attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
            )
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (
            num_tokens_across_dp_cpu is None
            or num_tokens_across_dp_cpu[dp_rank] == batchsize
        ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank
            )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context. On exit, the value
        it had before entering is restored, so nesting is safe.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            # Restore instead of clearing so an enclosing context's sizes
            # survive a nested use.
            self.local_sizes = prev_local_sizes

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking. The previous value is restored on exit.
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = prev_local_sizes

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the per-rank sizes set by the active ``chunked_sizes`` or
        ``sp_local_sizes`` context (asserts that one is active)."""
        assert self.local_sizes is not None
        return self.local_sizes

273

274
275
@dataclass
class ForwardContext:
    """State describing the model forward pass currently in progress.

    Built by ``create_forward_context``, installed by
    ``override_forward_context`` / ``set_forward_context``, and read back via
    ``get_forward_context``.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Attention metadata, set dynamically for each forward pass:
    # - Type AttentionMetadata for v0.
    # - Type dict[str, AttentionMetadata] for v1: map from layer_name of each
    #   attention layer to its attention metadata.
    # - Type list[dict[str, AttentionMetadata]] for DBO: list of size two, one
    #   for each microbatch.
    attn_metadata: Union[
        "AttentionMetadata",
        dict[str, "AttentionMetadata"],
        list[dict[str, "AttentionMetadata"]],
    ]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        # Reject cudagraph modes that are not valid at runtime.
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )
306
307
308
309
310
311


# The forward context currently in effect (None when no forward pass is
# active); swapped in and out by override_forward_context.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: if no forward context is currently set. The check is
            an explicit ``raise`` rather than an ``assert`` statement, so it
            is not stripped when Python runs with ``-O``.
    """
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context."
        )
    return _forward_context


320
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: Optional[DPMetadata] = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
    ubatch_slices: Optional[UBatchSlices] = None,
) -> ForwardContext:
    """Build (but do not install) a ForwardContext for one forward pass.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``. Every other
    field is passed through unchanged from the arguments.
    """
    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354


@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """Temporarily make ``forward_context`` the active forward context.

    Whatever context was active before (possibly ``None``) is reinstated when
    the ``with`` block exits, including on exceptions, so overrides nest.
    """
    global _forward_context
    saved_context, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved_context


355
@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
    ubatch_slices: Optional[UBatchSlices] = None,
):
    """A context manager that stores the current forward context,
    which can hold attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    When ``data_parallel_size > 1`` and either ``attn_metadata`` or
    ``num_tokens`` is given, a DPMetadata is built via ``DPMetadata.make``.
    That call may all-reduce across DP ranks when ``num_tokens_across_dp`` is
    None. When batch-size tracking is enabled and ``attn_metadata`` is not
    None, the forward latency is recorded on exit.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
        attn_metadata is not None or num_tokens is not None
    ):
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config,
            attn_metadata,
            num_tokens or 0,
            num_tokens_across_dp,
        )

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_forward_time(attn_metadata, num_tokens)


def _record_forward_time(attn_metadata: Any, num_tokens: Optional[int]) -> None:
    """Record the latency of the tracked forward pass that just finished.

    Appends the time elapsed since ``forward_start_time`` (in ms) under the
    pass's batch size. At most once per ``batchsize_logging_interval``
    seconds, logs the median latency of each batch size seen more than once.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    # NOTE(review): these per-batch-size lists grow for the whole lifetime of
    # the process; consider a bounded window if tracking stays enabled long.
    batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)

    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(
                (
                    "Batchsize forward time stats "
                    "(batchsize, count, median_time(ms)): %s"
                ),
                forward_stats,
            )