forward_context.py 16.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from typing import Any, NamedTuple
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.v1.attention.backend import AttentionMetadata
17
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
18
from vllm.v1.worker.ubatch_utils import UBatchSlices
19
20
21
22
23

logger = init_logger(__name__)

# Batch-size forward-time instrumentation. A negative
# VLLM_LOG_BATCHSIZE_INTERVAL disables it; otherwise it is the minimum gap
# (compared against time.perf_counter() deltas, i.e. seconds) between two
# stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value recorded when the current forward pass began.
forward_start_time: float = 0
# Batch size -> list of measured forward durations, in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
35

36
    num_tokens: int
37
    num_reqs: int | None = None
38
    """
39
40
41
42
43
44
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
45
    """
46
47
48
49
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
50
51
52
53
54
55
56
57
    num_active_loras: int = 0
    """
    Number of distinct active LoRA adapters in this batch.
    When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
    are captured for each num_active_loras value. This allows kernels
    (like fused_moe_lora) whose grid size depends on num_active_loras
    to be properly captured.
    """
58

59
    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
60
        """
61
62
        Return a relaxed version of current batch descriptor that is still compatible
        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
63
        """
64
        return BatchDescriptor(
65
66
67
68
69
            self.num_tokens,
            num_reqs=None,
            uniform=False,
            has_lora=self.has_lora,
            num_active_loras=self.num_active_loras,
70
        )
71
72


73
74
75
76
77
78
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
79
80
81
82
83

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


84
85
86
87
88
89
90
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Return how many tokens each rank processes in chunk ``chunk_idx``.

    Token counts are first spread over sequence-parallel ranks (taking into
    account sharding if the MoE activation is sequence parallel), then each
    rank's share is cut into chunks of at most ``max_num_tokens``. A rank
    with nothing left for this chunk still reports 1 so that all ranks keep
    executing in lockstep.

    Args:
        num_tokens_across_dp_cpu: Token counts indexed by DP rank.
        sequence_parallel_size: Number of SP ranks per DP rank.
        max_num_tokens: Maximum tokens a rank may process per chunk.
        chunk_idx: Index of the chunk to size.

    Returns:
        One entry per (DP rank, SP rank), each in ``[1, max_num_tokens]``
        when ``max_num_tokens >= 1``.
    """
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
    # Tokens already consumed by the chunks preceding ``chunk_idx``.
    consumed = max_num_tokens * chunk_idx
    # max(1, ...) ensures lockstep even for ranks that are already done.
    return [max(1, min(max_num_tokens, n - consumed)) for n in sp_tokens]
@dataclass
class DPMetadata:
    """Token-count bookkeeping across data-parallel ranks for one forward pass."""

    # torch.max over num_tokens_across_dp_cpu (see make()).
    max_tokens_across_dp_cpu: torch.Tensor
    # Token counts indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build DP metadata from already-coordinated per-rank token counts.

        Args:
            parallel_config: Must have ``data_parallel_size > 1`` and
                ``is_moe_model`` not ``False``.
            num_tokens: Number of tokens scheduled on this DP rank.
            num_tokens_across_dp_cpu: Token counts for every DP rank; the
                entry at ``data_parallel_rank`` must equal ``num_tokens``.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        assert parallel_config.is_moe_model is not False
        dp_rank = parallel_config.data_parallel_rank

        # The caller must already have coordinated num_tokens_across_dp
        # (e.g. via coordinate_batch_across_dp); this rank's entry has to
        # match the local token count.
        assert num_tokens_across_dp_cpu[dp_rank] == num_tokens, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {num_tokens}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context; on exit it is
        restored to the value it had on entry (None unless nested).

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        previous_sizes = self.local_sizes
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            # Restore rather than reset so an enclosing context is not clobbered.
            self.local_sizes = previous_sizes

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking.
        """
        previous_sizes = self.local_sizes
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            # Restore rather than reset so an enclosing context is not clobbered.
            self.local_sizes = previous_sizes

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the sizes set by the active chunked_sizes/sp_local_sizes context.

        Raises:
            AssertionError: if called outside of such a context.
        """
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs is distributed w.r.t. both DP and
        TP rank: each DP rank's token count is ceil-divided over ``sp_size``
        ranks. When ``sp_size == 1`` this is just the cumulative number of
        tokens across DP ranks.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@dataclass
class ForwardContext:
    """State shared by the layers of the model during one forward pass."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Set dynamically for each forward pass.
    # - dict[str, AttentionMetadata] for v1: maps the layer_name of each
    #   attention layer to its attention metadata.
    # - list[dict[str, AttentionMetadata]] for DBO: a list of size two, one
    #   dict for each microbatch.
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    # Slot-mapping tensors; create_forward_context defaults this to {}.
    # NOTE(review): presumably keyed by layer name with the same dict /
    # list-per-microbatch (DBO) layout as attn_metadata — confirm.
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None

    ubatch_slices: UBatchSlices | None = None

    # If True, bypass the compiled model call, e.g. by using .forward() directly
    skip_compiled: bool = False

    # For torch.compile cold start times, we need to avoid hard-coding
    # any strings into the graph. Right now, the vllm.moe_forward
    # and vllm.moe_forward_shared custom operators hard-code strings into
    # the graph.
    #
    # The workaround is to store a list of the strings that each of those
    # custom ops needs in the ForwardContext (all_moe_layers)
    # as well as a counter (moe_layer_index).
    # The ForwardContext object is alive for the duration of the forward pass.
    # When the custom op needs a layer string, get the next string
    # from all_moe_layers and increment the counter.
    #
    # This assumes that the custom operators will always be executed in
    # order and that torch.compile will not try to reorder these
    # operations with respect to each other.
    #
    # TODO(https://github.com/vllm-project/vllm/issues/31985):
    # There are longer-term solutions, like unwrapping the moe custom operator,
    # that aren't ready yet.
    # We could also treat the string as a "symbolic input" to the graph but
    # the PyTorch-side bits for that aren't ready yet either.
    #
    # If this value is None (like in some tests), then we end up baking the string
    # into the graph. Otherwise, the moe custom ops take the next string from
    # this list.
    all_moe_layers: list[str] | None = None
    moe_layer_index: int = 0

    additional_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Reject cudagraph modes that are not valid at runtime.
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )
# Module-level "current" forward context. It is swapped in and restored by
# override_forward_context (which set_forward_context uses) and read by
# get_forward_context / is_forward_context_available; None when no context
# is set.
_forward_context: ForwardContext | None = None
262
263
264


def get_forward_context() -> ForwardContext:
    """Return the forward context installed for the current forward pass.

    Raises:
        AssertionError: if no context has been set (see
            ``set_forward_context``).
    """
    ctx = _forward_context
    assert ctx is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return ctx
def is_forward_context_available() -> bool:
    """Report whether a forward context is currently set.

    Unlike ``get_forward_context``, this never raises.
    """
    current = _forward_context
    return current is not None
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
    additional_kwargs: dict[str, Any] | None = None,
    skip_compiled: bool = False,
):
    """Construct (but do not install) a ForwardContext for one forward pass.

    ``no_compile_layers`` comes from
    ``compilation_config.static_forward_context``. When
    ``compilation_config.fast_moe_cold_start`` is enabled,
    ``static_all_moe_layers`` is attached as ``all_moe_layers``; otherwise it
    is None. A falsy ``slot_mapping`` / ``additional_kwargs`` is replaced by
    a fresh empty dict.
    """
    compilation_config = vllm_config.compilation_config
    all_moe_layers = (
        compilation_config.static_all_moe_layers
        if compilation_config.fast_moe_cold_start
        else None
    )
    return ForwardContext(
        no_compile_layers=compilation_config.static_forward_context,
        attn_metadata=attn_metadata,
        slot_mapping=slot_mapping if slot_mapping else {},
        virtual_engine=virtual_engine,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
        skip_compiled=skip_compiled,
        all_moe_layers=all_moe_layers,
        additional_kwargs=additional_kwargs if additional_kwargs else {},
    )
@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """Temporarily make ``forward_context`` the current forward context.

    The context active on entry (possibly None) is reinstated on exit, even
    if the body raises, so overrides nest cleanly.
    """
    global _forward_context
    saved_context, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved_context
def _record_batchsize_forward_time(batchsize: int | None, start_time: float) -> None:
    """Record one forward duration for ``batchsize`` and periodically log stats.

    Durations are stored in ``batchsize_forward_time`` in milliseconds. At
    most once per ``batchsize_logging_interval`` seconds, the median duration
    of every batch size seen more than once is logged, most frequent first.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    # NOTE(review): this history is never cleared, so it grows for as long as
    # tracking is enabled.
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)

    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(
                (
                    "Batchsize forward time stats "
                    "(batchsize, count, median_time(ms)): %s"
                ),
                forward_stats,
            )


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
    skip_compiled: bool = False,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: Attention metadata for this pass. When None, batch-size
            timing is skipped.
        vllm_config: Engine config; its parallel and compilation settings
            decide whether DP metadata and MoE layer names are attached.
        virtual_engine: Virtual engine index for this pass.
        num_tokens: Number of tokens in this pass on this rank. Used for DP
            coordination, the default batch descriptor, and the timing key.
        num_tokens_across_dp: Pre-coordinated per-DP-rank token counts. When
            None in a DP (size > 1) MoE setup, they are computed here with DP
            padding and microbatching disabled.
        cudagraph_runtime_mode: Cudagraph mode for this pass. When not NONE
            and ``num_tokens`` is given, a missing ``batch_descriptor``
            defaults to ``BatchDescriptor(num_tokens=num_tokens)``.
        batch_descriptor: Descriptor used for cudagraph dispatching.
        ubatch_slices: Microbatch slices; must be None if
            ``num_tokens_across_dp`` has to be computed here.
        slot_mapping: Forwarded to the created ForwardContext.
        skip_compiled: Forwarded to the created ForwardContext.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    # Keep this pass's start time locally: a nested set_forward_context
    # overwrites the module-level forward_start_time before we finish.
    pass_start_time = 0.0
    if need_to_track_batchsize:
        pass_start_time = forward_start_time = time.perf_counter()

    dp_metadata: DPMetadata | None = None
    if (
        vllm_config.parallel_config.data_parallel_size > 1
        and vllm_config.parallel_config.is_moe_model is not False
        and (attn_metadata is not None or num_tokens is not None)
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
                allow_dp_padding=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    additional_kwargs = current_platform.set_additional_forward_context(
        attn_metadata=attn_metadata,
        vllm_config=vllm_config,
        virtual_engine=virtual_engine,
        dp_metadata=dp_metadata,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
        slot_mapping,
        additional_kwargs,
        skip_compiled,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_batchsize_forward_time(num_tokens, pass_start_time)