forward_context.py 13.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import Any, NamedTuple
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.attention.backends.abstract import AttentionMetadata
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
17
from vllm.v1.worker.ubatch_utils import UBatchSlices
18
19
20
21
22

logger = init_logger(__name__)

# Batch-size timing statistics. They are enabled when
# VLLM_LOG_BATCHSIZE_INTERVAL is non-negative; the interval is read once and
# used both as the switch and as the logging period.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() timestamps maintained by set_forward_context.
forward_start_time: float = 0
last_logging_time: float = 0
# Batch size -> list of measured forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
26
27


28
29
30
31
32
33
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
34

35
    num_tokens: int
36
    num_reqs: int | None = None
37
    """
38
39
40
41
42
43
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
44
    """
45
46
47
48
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
49

50
    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
51
        """
52
53
        Return a relaxed version of current batch descriptor that is still compatible
        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
54
        """
55
        return BatchDescriptor(
56
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
57
        )
58
59


60
61
62
63
64
65
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
66
67
68
69
70

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


71
72
73
74
75
76
77
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Token count each rank contributes to chunk ``chunk_idx``.

    The per-rank counts come from ``_compute_sp_num_tokens`` (so sharding is
    taken into account when the MoE activation is sequence parallel). Each
    rank's tokens are split into chunks of at most ``max_num_tokens``. A rank
    that has no tokens left for the requested chunk reports 1 instead of a
    non-positive value, so that all ranks keep stepping in lockstep.
    """
    chunk_start = max_num_tokens * chunk_idx

    def _clamp(remaining: int) -> int:
        # At most one chunk's worth; at least 1 (lockstep even if done).
        size = min(max_num_tokens, remaining)
        return size if size > 0 else 1

    per_rank = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
    return [_clamp(count - chunk_start) for count in per_rank]


89
90
@dataclass
class DPMetadata:
    """Token-count bookkeeping for a forward pass under data parallelism (DP)."""

    # Maximum of ``num_tokens_across_dp_cpu`` (computed with torch.max in make()).
    max_tokens_across_dp_cpu: torch.Tensor
    # Number of tokens on each DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers; it is only meaningful inside them.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build DP metadata from already-coordinated per-rank token counts.

        Args:
            parallel_config: Must have ``data_parallel_size > 1``; its
                ``data_parallel_rank`` selects this rank's entry.
            num_tokens: Number of tokens on this DP rank.
            num_tokens_across_dp_cpu: Token counts for every DP rank. Must not
                be None, and its entry for this rank must equal
                ``num_tokens``.

        Returns:
            A DPMetadata whose ``max_tokens_across_dp_cpu`` is the maximum of
            ``num_tokens_across_dp_cpu``.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        dp_rank = parallel_config.data_parallel_rank

        # The counts must already have been coordinated across DP ranks by
        # the caller, so this rank's entry has to match the local batch size.
        assert num_tokens_across_dp_cpu[dp_rank] == num_tokens, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {num_tokens}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context; on exit the value
        it had before entering is restored.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                we use SP between the layers to avoid redundant ops. We need
                this value to compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.

        Yields:
            The per-rank sizes that were stored in `self.local_sizes`.
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            # Restore instead of clearing so a nested use does not wipe out
            # the sizes of an enclosing context.
            self.local_sizes = prev_local_sizes

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
        but without any chunking: each rank's full SP-sharded token count is
        used. The previous value of self.local_sizes is restored on exit.
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            # Restore instead of clearing so nested contexts compose.
            self.local_sizes = prev_local_sizes

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the per-rank sizes set by the active ``chunked_sizes`` or
        ``sp_local_sizes`` context.

        Must be called inside one of those contexts (asserted).
        """
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Cumulative token counts across sequence-parallel ranks.

        The input to the MoEs is distributed w.r.t. both the DP and the TP
        rank, so each DP rank's token count is ceil-divided by ``sp_size``
        and repeated ``sp_size`` times before the running sum is taken.
        When ``sp_size == 1`` this is just the cumulative number of tokens
        across DP ranks.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

184

185
186
@dataclass
class ForwardContext:
    """Per-forward-pass state, returned by get_forward_context()."""

    # Copy of vllm_config.compilation_config.static_forward_context.
    no_compile_layers: dict[str, Any]
    # Set dynamically for each forward pass. For v1 this maps the layer_name
    # of each attention layer to its attention metadata; for DBO it is a list
    # of two such dicts, one per microbatch.
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    # Set dynamically for each forward pass.
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int
    # Set dynamically for each forward pass.
    dp_metadata: DPMetadata | None = None
    # Cudagraph style decided at runtime: FULL, PIECEWISE, or NONE. The
    # default, NONE, means no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None
    ubatch_slices: UBatchSlices | None = None

    def __post_init__(self):
        # Fail fast if CUDAGraphMode.valid_runtime_modes() rejects the mode.
        mode = self.cudagraph_runtime_mode
        assert mode.valid_runtime_modes(), f"Invalid cudagraph runtime mode: {mode}"
212
213


214
# Forward context of the pass currently in progress, or None when no context
# is active. Read by get_forward_context() and swapped in/out by
# override_forward_context().
_forward_context: ForwardContext | None = None
215
216
217


def get_forward_context() -> ForwardContext:
    """Return the forward context that is currently active.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    ctx = _forward_context
    assert ctx is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return ctx


226
227
228
229
def is_forward_context_available() -> bool:
    """Return True when a forward context is currently set, else False."""
    current = _forward_context
    return current is not None


230
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """Build (but do not install) a ForwardContext.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``; every other
    argument is passed through to the ForwardContext unchanged.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    return ForwardContext(
        no_compile_layers=static_layers,
        attn_metadata=attn_metadata,
        virtual_engine=virtual_engine,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )
248
249
250


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """Temporarily make ``forward_context`` the current forward context.

    Whatever context was active before (possibly None) is put back on exit,
    including when the body raises, so overrides can be nested.
    """
    global _forward_context
    saved, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved


265
def _record_batchsize_forward_time(batchsize: int | None) -> None:
    """Record the latency of the forward pass that just finished and
    periodically log per-batch-size median latencies.

    The elapsed time since the module-level ``forward_start_time`` is measured
    after synchronizing the platform, appended in milliseconds to
    ``batchsize_forward_time[batchsize]``, and a summary is logged whenever
    more than ``batchsize_logging_interval`` seconds (perf_counter units) have
    passed since the last summary.
    """
    global last_logging_time

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform

    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)

    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(
                (
                    "Batchsize forward time stats "
                    "(batchsize, count, median_time(ms)): %s"
                ),
                forward_stats,
            )


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """A context manager that stores the current forward context (attention
    metadata, etc.) for the duration of a model forward pass. Common logic
    for every forward pass can be injected here.

    The context is built with ``create_forward_context`` and installed with
    ``override_forward_context``, which restores the previous context on exit.

    When ``data_parallel_size > 1`` and either ``attn_metadata`` or
    ``num_tokens`` is given, a ``DPMetadata`` is attached. If
    ``num_tokens_across_dp`` is not supplied, it is obtained from
    ``coordinate_batch_across_dp`` with DP padding and microbatching disabled;
    that path requires ``num_tokens`` and no ``ubatch_slices``.

    When ``cudagraph_runtime_mode`` is not NONE and ``num_tokens`` is given, a
    default ``BatchDescriptor(num_tokens=num_tokens)`` is used if no batch
    descriptor was passed.

    When batch-size tracking is enabled (``track_batchsize``) and
    ``attn_metadata`` is not None, the pass is timed on exit and the
    per-batch-size statistics are periodically logged.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: DPMetadata | None = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
        attn_metadata is not None or num_tokens is not None
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
                allow_dp_padding=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        # Timing is recorded even if the forward pass raised.
        if need_to_track_batchsize:
            _record_batchsize_forward_time(num_tokens)