# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import Any, NamedTuple
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.attention.backends.abstract import AttentionMetadata
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
17
from vllm.v1.worker.ubatch_utils import UBatchSlices
18
19
20
21
22

logger = init_logger(__name__)

# Batch-size timing statistics; tracking is enabled whenever
# VLLM_LOG_BATCHSIZE_INTERVAL is non-negative.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value taken at the start of the latest tracked forward.
forward_start_time: float = 0
# Maps batch size (num_tokens) -> recorded forward latencies in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
26
27


28
29
30
31
32
33
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
34

35
    num_tokens: int
36
    num_reqs: int | None = None
37
    """
38
39
40
41
42
43
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
44
    """
45
46
47
48
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
49

50
    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
51
        """
52
53
        Return a relaxed version of current batch descriptor that is still compatible
        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
54
        """
55
        return BatchDescriptor(
56
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
57
        )
58
59


60
61
62
63
64
65
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
66
67
68
69
70

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


71
72
73
74
75
76
77
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Number of tokens every (DP, SP) rank processes in chunk `chunk_idx`.

    Each rank's share is capped at `max_num_tokens` per chunk. A rank whose
    tokens are already exhausted still reports 1 token.
    """
    # Take into account sharding if MoE activation is sequence parallel.
    per_rank_tokens = _compute_sp_num_tokens(
        num_tokens_across_dp_cpu, sequence_parallel_size
    )
    chunk_start = max_num_tokens * chunk_idx
    # Clamp to at least 1: ensure lockstep even if a rank is done.
    return [
        max(1, min(max_num_tokens, tokens - chunk_start))
        for tokens in per_rank_tokens
    ]


89
90
@dataclass
class DPMetadata:
    """Token counts of the current forward pass across data-parallel ranks."""

    # 0-d tensor holding the largest entry of num_tokens_across_dp_cpu.
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of every DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build DP metadata from already-gathered per-rank token counts.

        Args:
            parallel_config: Parallel config; data_parallel_size must be > 1
                and is_moe_model must not be False.
            num_tokens: Token count of this rank. It must equal this rank's
                entry in `num_tokens_across_dp_cpu`.
            num_tokens_across_dp_cpu: Token counts of every DP rank. Required;
                it is not recomputed here.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        # Only an explicit False is rejected (None passes).
        assert parallel_config.is_moe_model is not False
        dp_rank = parallel_config.data_parallel_rank
        batchsize = num_tokens

        # The caller supplies num_tokens_across_dp; this rank's entry must
        # agree with the local batch size.
        assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def _use_local_sizes(self, local_sizes: list[int]):
        """Expose `local_sizes` as `self.local_sizes` for the duration of the
        context, then restore whatever value it held before.

        Restoring the previous value (instead of unconditionally resetting to
        None) keeps an enclosing chunked_sizes / sp_local_sizes context valid
        when these context managers are nested.
        """
        previous = self.local_sizes
        self.local_sizes = local_sizes
        try:
            yield local_sizes
        finally:
            self.local_sizes = previous

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context. On exit, the
        value it held on entry is restored.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.

        Yields:
            The per-rank token counts (the same list as `self.local_sizes`).
        """
        local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        with self._use_local_sizes(local_sizes):
            yield local_sizes

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking.

        Yields:
            The per-rank token counts (the same list as `self.local_sizes`).
        """
        local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        with self._use_local_sizes(local_sizes):
            yield local_sizes

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the per-rank token counts installed by the active
        chunked_sizes / sp_local_sizes context. Must be called inside one."""
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs will be distributed w.r.t both
        DP and TP rank. Each DP rank's count is ceil-divided by `sp_size` and
        repeated `sp_size` times before the running sum is taken.
        When sp_size==1, this is just the cumulative num tokens across DP.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

185

186
187
@dataclass
class ForwardContext:
    """State shared by the layers of a model during one forward pass."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Set dynamically for each forward pass.
    # - dict[str, AttentionMetadata] for v1: maps the layer_name of each
    #   attention layer to its attention metadata.
    # - list[dict[str, AttentionMetadata]] for DBO: a list of size two, one
    #   dict for each microbatch.
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # Cudagraph style chosen at runtime: FULL, PIECEWISE, or NONE. Defaults
    # to NONE, i.e. no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None
    ubatch_slices: UBatchSlices | None = None

    def __post_init__(self):
        # Reject modes that are not valid at runtime.
        mode = self.cudagraph_runtime_mode
        assert mode.valid_runtime_modes(), f"Invalid cudagraph runtime mode: {mode}"
213
214


215
# The currently installed forward context, or None when none is installed.
# Swapped in and out by `override_forward_context`.
_forward_context: ForwardContext | None = None


def get_forward_context() -> ForwardContext:
    """Return the currently installed forward context.

    Raises:
        AssertionError: (when assertions are enabled) if no forward context
            is set.
    """
    current = _forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return current


def is_forward_context_available() -> bool:
    """Return True when a forward context is currently installed."""
    current = _forward_context
    return current is not None


231
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
) -> ForwardContext:
    """Build a ForwardContext without installing it.

    `no_compile_layers` is taken from
    `vllm_config.compilation_config.static_forward_context`. Every other field
    is copied from the arguments. Pass the result to `override_forward_context`
    to make it the active context.
    """
    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
249
250
251


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """Temporarily install `forward_context` as the current forward context.

    This is used to override the forward context for a specific forward pass.
    The previously installed context (possibly None) is put back on exit,
    even if the body raises.
    """
    global _forward_context
    saved_context, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved_context


266
def _record_batchsize_forward_time(
    batchsize: int | None, start_time: float
) -> None:
    """Record one forward pass's latency and periodically log aggregate stats.

    Args:
        batchsize: Key under which the latency is recorded (the `num_tokens`
            given to `set_forward_context`; may be None).
        start_time: `time.perf_counter()` value taken when the pass started.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform

    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    # NOTE(review): samples accumulate for the lifetime of the process, so
    # memory use and the per-log quantile cost grow with every tracked pass.
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
        forward_stats.append((bs, len(times), median))
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(
            (
                "Batchsize forward time stats "
                "(batchsize, count, median_time(ms)): %s"
            ),
            forward_stats,
        )


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: Stored in the context. When it is not None and
            batch-size tracking is enabled, the forward latency is recorded.
        vllm_config: Source of the parallel and compilation configs.
        virtual_engine: Stored in the context.
        num_tokens: Token count of this rank. Used for DP metadata, for the
            default BatchDescriptor, and as the batch-size stats key.
        num_tokens_across_dp: Per-DP-rank token counts. If None (and DP > 1),
            they are computed via `coordinate_batch_across_dp` with DP padding
            and microbatching disabled; this requires `num_tokens` and no
            `ubatch_slices`.
        cudagraph_runtime_mode: Stored in the context.
        batch_descriptor: Stored in the context. If None while a cudagraph
            mode is active and `num_tokens` is given, one is built from
            `num_tokens`.
        ubatch_slices: Stored in the context.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    # Keep this pass's start time local so that nested forward contexts
    # cannot overwrite each other's timestamp. The module global is still
    # updated for any external readers.
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        forward_start_time = start_time

    dp_metadata: DPMetadata | None = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
        attn_metadata is not None or num_tokens is not None
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
                allow_dp_padding=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_batchsize_forward_time(num_tokens, start_time)