# forward_context.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import time
6
from collections import defaultdict
7
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
from typing import Any, NamedTuple
10

11
12
import torch

13
import vllm.envs as envs
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16

17
from vllm.platforms import current_platform
18
from vllm.v1.attention.backend import AttentionMetadata
19
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
20
from vllm.v1.worker.ubatch_utils import UBatchSlices
21

22
23
24
25
logger = init_logger(__name__)

# Batch-size timing statistics are collected only when the configured
# logging interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() reading taken the last time the statistics were logged.
last_logging_time: float = 0
# time.perf_counter() reading taken when the current forward pass started.
forward_start_time: float = 0
# Minimum gap between two statistics log lines; compared against differences
# of time.perf_counter() readings.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps a batch size (the `num_tokens` of a forward pass) to the list of
# measured forward times in milliseconds.
batchsize_forward_time: defaultdict[int | None, list[float]] = defaultdict(list)
29
30


zhuwenwen's avatar
zhuwenwen committed
31
32
33
34
35
36
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
37

zhuwenwen's avatar
zhuwenwen committed
38
    num_tokens: int
39
    num_reqs: int | None = None
40
    """
41
42
43
44
45
46
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
47
    """
48
    has_lora: bool = False
zhuwenwen's avatar
zhuwenwen committed
49
    """
50
    Whether this batch has active LoRA adapters.
zhuwenwen's avatar
zhuwenwen committed
51
52
    """

53
    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
zhuwenwen's avatar
zhuwenwen committed
54
        """
55
56
        Return a relaxed version of current batch descriptor that is still compatible
        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
zhuwenwen's avatar
zhuwenwen committed
57
        """
58
        return BatchDescriptor(
59
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
60
        )
zhuwenwen's avatar
zhuwenwen committed
61
62


63
64
65
66
67
68
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
69
70
71
72
73

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


74
75
76
77
78
79
80
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Return how many tokens each rank processes in chunk `chunk_idx`.

    The per-rank totals come from `_compute_sp_num_tokens`, which takes
    sharding into account if MoE activation is sequence parallel. For each
    rank the chunk size is whatever remains after `chunk_idx` full chunks of
    `max_num_tokens`, capped at `max_num_tokens` and never below 1.
    """
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
    tokens_in_earlier_chunks = max_num_tokens * chunk_idx
    # A rank that has nothing left still reports 1 to ensure lockstep even
    # if done.
    return [
        max(1, min(max_num_tokens, total - tokens_in_earlier_chunks))
        for total in sp_tokens
    ]


92
93
@dataclass
class DPMetadata:
    """Token-count bookkeeping shared across data-parallel (DP) ranks."""

    # Maximum of num_tokens_across_dp_cpu, as produced by torch.max.
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of every DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes and
    # sp_local_sizes context managers; it is None outside of them.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build the metadata for one forward pass.

        Args:
            parallel_config: Parallel configuration; must describe a run with
                more than one DP rank.
            num_tokens: Token count of the calling rank.
            num_tokens_across_dp_cpu: Token count of every DP rank. The entry
                for the calling rank must equal `num_tokens`.

        Raises:
            AssertionError: If any of the preconditions above is violated.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        assert parallel_config.is_moe_model is not False
        dp_rank = parallel_config.data_parallel_rank
        batchsize = num_tokens

        # num_tokens_across_dp_cpu[dp_rank] should be equal to batchsize
        assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            # Reset even if the body raised, so stale sizes never leak out.
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            # Reset even if the body raised, so stale sizes never leak out.
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return `self.local_sizes`.

        Raises:
            AssertionError: If `self.local_sizes` is None, i.e. the call is
                made outside `chunked_sizes` / `sp_local_sizes`.
        """
        assert self.local_sizes is not None
        return self.local_sizes

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        # Ceiling division of every rank's token count by sp_size.
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

188

189
190
@dataclass
class ForwardContext:
    """State made available to layers for the duration of one forward pass."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    """
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    attention layer to its attention metadata
    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
    for each microbatch.
    Set dynamically for each forward pass
    """
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None

    ubatch_slices: UBatchSlices | None = None

    # Extra key/value pairs attached to the context; defaults to a fresh
    # empty dict per instance.
    additional_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Reject cudagraph modes that are not usable as a runtime mode.
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )
218
219


220
# The forward context currently installed by `override_forward_context`;
# None when no context is installed.
_forward_context: ForwardContext | None = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: If no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return _forward_context


232
233
234
235
def is_forward_context_available() -> bool:
    """Return True if a forward context is currently set, False otherwise."""
    context_missing = _forward_context is None
    return not context_missing


236
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    additional_kwargs: dict[str, Any] | None = None,
) -> ForwardContext:
    """Build a `ForwardContext` from the given per-forward-pass values.

    `no_compile_layers` is taken from
    `vllm_config.compilation_config.static_forward_context`; every other field
    is passed through unchanged. A falsy `additional_kwargs` (None or an empty
    dict) is replaced by a new empty dict.
    """
    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
        additional_kwargs=additional_kwargs or {},
    )
256
257
258


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """A context manager that overrides the current forward context.
    This is used to override the forward context for a specific
    forward pass.

    The context that was active on entry is put back on exit, including when
    the body raises.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context


273
def _record_batchsize_forward_time(batchsize: int | None) -> None:
    """Record the duration of the forward pass that just ended and, at most
    once per `batchsize_logging_interval`, log per-batch-size statistics."""
    global last_logging_time

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now

    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
        forward_stats.append((bs, len(times), median))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(
            (
                "Batchsize forward time stats "
                "(batchsize, count, median_time(ms)): %s"
            ),
            forward_stats,
        )


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    On entry a `ForwardContext` is built from the arguments and installed via
    `override_forward_context`; on exit the previous context is restored and,
    when batch-size tracking is enabled and `attn_metadata` is not None, the
    elapsed time of the pass is recorded.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: DPMetadata | None = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
        attn_metadata is not None or num_tokens is not None
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
                allow_dp_padding=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    # Let the current platform contribute extra entries for the context.
    additional_kwargs = current_platform.set_additional_forward_context(
        attn_metadata=attn_metadata,
        vllm_config=vllm_config,
        virtual_engine=virtual_engine,
        dp_metadata=dp_metadata,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
        additional_kwargs,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_batchsize_forward_time(num_tokens)
378
                    
379
380
381
382
383
384
385
386
387
388
# Module-wide flag written by set_profilling() and read by get_profilling().
_profiling: bool = False


def set_profilling(profiling):
    """Set the module-wide profiling flag.

    The flag is updated as soon as this function is called, so a plain
    ``set_profilling(True)`` call keeps working as a setter. The return value
    is a context manager: when the call is used in a ``with`` statement, the
    value the flag had before the call is restored when the block exits.

    Args:
        profiling: New value of the flag.
    """
    global _profiling
    previous = _profiling
    _profiling = profiling

    @contextmanager
    def _restore_on_exit():
        global _profiling
        try:
            yield
        finally:
            _profiling = previous

    return _restore_on_exit()


def get_profilling() -> bool:
    """Return the current value of the module-wide profiling flag."""
    return _profiling