forward_context.py 13.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, NamedTuple
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
14
from vllm.logger import init_logger
15
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
16
from vllm.v1.worker.ubatch_utils import UBatchSlices
17

18
19
20
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

21
22
23
24
logger = init_logger(__name__)

# Batch-size timing instrumentation for set_forward_context. Tracking is on
# when VLLM_LOG_BATCHSIZE_INTERVAL is non-negative.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
last_logging_time: float = 0
forward_start_time: float = 0
batchsize_forward_time: defaultdict = defaultdict(list)
28
29


30
31
32
33
34
35
class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
36

37
    num_tokens: int
38
    num_reqs: int | None = None
39
    """
40
41
42
43
44
45
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
46
    """
47
48
49
50
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
51

52
    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
53
        """
54
55
        Return a relaxed version of current batch descriptor that is still compatible
        with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
56
        """
57
        return BatchDescriptor(
58
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
59
        )
60
61


62
63
64
65
66
67
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
68
69
70
71
72

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


73
74
75
76
77
78
79
def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Return how many tokens each rank handles in chunk ``chunk_idx``.

    Per-rank totals come from ``_compute_sp_num_tokens``, which takes into
    account sharding if the MoE activation is sequence parallel. Each rank
    processes at most ``max_num_tokens`` tokens per chunk. A rank whose tokens
    are already exhausted by this chunk still reports 1, so that all ranks
    stay in lockstep.
    """
    chunk_start = max_num_tokens * chunk_idx
    sizes: list[int] = []
    for remaining in _compute_sp_num_tokens(
        num_tokens_across_dp_cpu, sequence_parallel_size
    ):
        size = min(max_num_tokens, remaining - chunk_start)
        # ensure lockstep even if done
        sizes.append(size if size > 0 else 1)
    return sizes


91
92
@dataclass
class DPMetadata:
    """Data-parallel (DP) token bookkeeping for one forward pass.

    Stores the token count of every DP rank. While the ``chunked_sizes`` or
    ``sp_local_sizes`` context manager is active, it also stores the per-rank
    local sizes derived from those counts.
    """

    # Maximum of ``num_tokens_across_dp_cpu`` (computed with ``torch.max`` in
    # ``make``).
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of each DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes or
    # sp_local_sizes context managers; both reset it to None on exit.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build DP metadata from already-gathered per-rank token counts.

        Args:
            parallel_config: Parallel config. ``data_parallel_size`` must be
                greater than 1, and ``data_parallel_rank`` selects this rank.
            num_tokens: Number of tokens on this DP rank.
            num_tokens_across_dp_cpu: Token count of every DP rank. It must
                be provided (not None), and this rank's entry must equal
                ``num_tokens``.

        Returns:
            A ``DPMetadata`` with ``local_sizes`` unset.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        dp_rank = parallel_config.data_parallel_rank
        batchsize = num_tokens

        # The per-rank counts must already be known here (None is rejected
        # above). Sanity-check that this rank's entry agrees with num_tokens.
        assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This ensures each DP (data parallel) rank processes its designated
        portion of tokens in lockstep with the others, even when token counts
        are uneven or some ranks have finished their input early.

        For chunked execution, the tokens on each rank are split into chunks
        of at most `max_chunk_size_per_rank`. For a given `chunk_idx`, this
        context manager sets `self.local_sizes` to the number of tokens each
        rank processes in that chunk.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                we use SP between the layers to avoid redundant ops. This
                value is needed to compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.

        Yields:
            The per-rank local sizes (also stored in `self.local_sizes`).
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as
        self.chunked_sizes but without any chunking.

        Yields:
            The per-SP-rank token counts (also stored in `self.local_sizes`).
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the local sizes set by the active size context manager.

        Raises:
            AssertionError: If called outside `chunked_sizes` /
                `sp_local_sizes` (i.e. `local_sizes` is None).
        """
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Return the cumulative token counts across sequence parallel ranks.

        Each DP rank's tokens are divided over ``sp_size`` ranks, rounding up,
        so the input to the MoEs is distributed with respect to both the DP
        and the TP rank. When ``sp_size == 1``, this is just the cumulative
        number of tokens across DP.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

186

187
188
@dataclass
class ForwardContext:
    """Per-forward-pass state, typically built by `create_forward_context`.

    The instance currently installed is retrieved with `get_forward_context`.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # For v1: a dict mapping the layer_name of each attention layer to its
    # AttentionMetadata. For DBO: a list of two such dicts, one per
    # microbatch. Set dynamically for each forward pass.
    attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # Cudagraph style chosen at runtime: FULL, PIECEWISE, or NONE.
    # Defaults to NONE (no cudagraph is used).
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None
    ubatch_slices: UBatchSlices | None = None

    def __post_init__(self):
        mode = self.cudagraph_runtime_mode
        # Reject modes that are not valid at runtime.
        assert mode.valid_runtime_modes(), f"Invalid cudagraph runtime mode: {mode}"
214
215


216
# The forward context currently in effect; None when no context is installed.
# Swapped in and restored by `override_forward_context`.
_forward_context: ForwardContext | None = None
217
218
219


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The context installed by `set_forward_context` (or
        `override_forward_context`).

    Raises:
        AssertionError: If no forward context is set. It is raised explicitly
            rather than via an ``assert`` statement, so the check still runs
            under ``python -O`` instead of silently returning None.
    """
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context."
        )
    return _forward_context


228
229
230
231
def is_forward_context_available() -> bool:
    """Report whether a forward context is currently installed.

    Unlike `get_forward_context`, this only inspects the module state and
    never raises.
    """
    current = _forward_context
    return current is not None


232
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """Build (but do not install) a ForwardContext.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``. Every other
    field is passed through unchanged from the arguments.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    return ForwardContext(
        attn_metadata=attn_metadata,
        no_compile_layers=static_layers,
        virtual_engine=virtual_engine,
        dp_metadata=dp_metadata,
        batch_descriptor=batch_descriptor,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        ubatch_slices=ubatch_slices,
    )
250
251
252


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """Temporarily replace the module's current forward context.

    Inside the ``with`` block, ``forward_context`` (which may be None) is the
    active context. On exit, including exit by exception, the previously
    active context is restored.
    """
    global _forward_context
    saved, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved


267
@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """A context manager that stores the current forward context,
    which can hold attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: Stored on the context. When not None, it also enables
            batch-size timing (if tracking is on) and DP metadata creation
            (if data_parallel_size > 1).
        vllm_config: Source of the parallel config and of the static
            forward-context layers.
        virtual_engine: Stored on the context.
        num_tokens: Number of tokens in this pass. Used for DP metadata, for
            the default batch descriptor, and as the batch-size key for
            timing.
        num_tokens_across_dp: Per-DP-rank token counts. If None while DP is
            active, they are computed here with DP padding and microbatching
            disabled (``ubatch_slices`` must then be None).
        cudagraph_runtime_mode: Stored on the context.
        batch_descriptor: Stored on the context. If None while a cudagraph
            mode is set and ``num_tokens`` is given, one is built from
            ``num_tokens``.
        ubatch_slices: Stored on the context.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        # Measure against a local start time. The module-level
        # forward_start_time is shared, so a nested or overlapping
        # set_forward_context would otherwise overwrite it mid-measurement.
        # The global is still updated for any code that reads it.
        start_time = time.perf_counter()
        forward_start_time = start_time

    dp_metadata: DPMetadata | None = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
        attn_metadata is not None or num_tokens is not None
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
                allow_dp_padding=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_batchsize_forward_time(num_tokens, start_time)


def _record_batchsize_forward_time(batchsize: int | None, start_time: float) -> None:
    """Record one tracked forward pass's latency and periodically log stats.

    Appends the elapsed milliseconds since ``start_time`` to
    ``batchsize_forward_time[batchsize]``. At most once per
    ``batchsize_logging_interval`` seconds of ``time.perf_counter`` time, it
    logs the median latency of every batch size seen more than once, ordered
    by count (descending).
    """
    global last_logging_time

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform

    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now

    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(
            "Batchsize forward time stats (batchsize, count, median_time(ms)): %s",
            forward_stats,
        )