"tools/vscode:/vscode.git/clone" did not exist on "ec68d53b2b75eb5480270c67676b126079998f5a"
forward_context.py 13.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from typing import Any
9

10
11
import torch

12
import vllm.envs as envs
13
import vllm.ir
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.v1.attention.backend import AttentionMetadata
18
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
19
from vllm.v1.worker.ubatch_utils import UBatchSlices
20
21
22
23
24

logger = init_logger(__name__)

# Batch-size timing is enabled when VLLM_LOG_BATCHSIZE_INTERVAL is
# non-negative. The same value is compared against time.perf_counter()
# differences, i.e. it is the minimum number of seconds between stat logs.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# perf_counter() reading taken the last time the stats were logged.
last_logging_time: float = 0
# perf_counter() reading taken when the current forward pass started.
forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps the num_tokens of a forward pass to the list of measured forward
# times (milliseconds) recorded for that batch size.
batchsize_forward_time: defaultdict[int | None, list[float]] = defaultdict(list)
28
29


30
31
@dataclass(frozen=True)
class BatchDescriptor:
32
33
34
35
36
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
37

38
    num_tokens: int
39
    num_reqs: int | None = None
40
    """
41
42
43
44
45
46
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
47
    """
48
49
50
51
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
52
53
54
55
56
57
58
59
    num_active_loras: int = 0
    """
    Number of distinct active LoRA adapters in this batch.
    When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
    are captured for each num_active_loras value. This allows kernels
    (like fused_moe_lora) whose grid size depends on num_active_loras
    to be properly captured.
    """
60
61


62
63
64
65
66
67
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
68
69
70
71
72

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


73
74
@dataclass
class DPMetadata:
    """Token counts of every data-parallel (DP) rank for one forward pass.

    ``num_tokens_across_dp_cpu`` is indexed by DP rank (see ``make``, which
    checks the entry at ``data_parallel_rank`` against the local batch size).
    """

    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the sp_local_sizes context
    # manager, which resets it to None on exit.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build a DPMetadata after sanity-checking the inputs.

        Args:
            parallel_config: Supplies ``data_parallel_size``,
                ``data_parallel_rank`` and ``is_moe_model``.
            num_tokens: Token count of the local batch.
            num_tokens_across_dp_cpu: Token counts indexed by DP rank.

        Returns:
            A DPMetadata wrapping ``num_tokens_across_dp_cpu``.

        Raises:
            AssertionError: If the tensor is None, ``data_parallel_size`` is
                not greater than 1, ``is_moe_model`` is False, or the entry
                for this rank differs from ``num_tokens``.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        assert parallel_config.is_moe_model is not False
        dp_rank = parallel_config.data_parallel_rank
        batchsize = num_tokens

        # num_tokens_across_dp_cpu[dp_rank] should be equal to batchsize
        assert num_tokens_across_dp_cpu[dp_rank] == batchsize, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        )
        return DPMetadata(num_tokens_across_dp_cpu)

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes.

        On entry, local_sizes is computed from num_tokens_across_dp_cpu by
        _compute_sp_num_tokens and yielded; on exit (including on error) it is
        reset to None.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> list[int] | None:
        """Return local_sizes; asserts that it is currently set."""
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs will be distributed w.r.t both
        DP and TP rank.
        When sp_size==1, this is just the cumulative num tokens across DP.

        Each DP entry is ceil-divided by ``sp_size``, repeated ``sp_size``
        times, and the running sum over the resulting sequence is returned.
        """
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

128

129
130
@dataclass
class ForwardContext:
    """Per-forward-pass state made available through get_forward_context().

    Instances are built by create_forward_context(). ``__post_init__``
    rejects any ``cudagraph_runtime_mode`` that is not a valid runtime mode.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    """
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    attention layer to its attention metadata
    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
    for each microbatch.
    Set dynamically for each forward pass
    """
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
    """
    Either one dict of tensors keyed by string, or a list of such dicts,
    mirroring the two forms attn_metadata can take. Keys are presumably
    layer names as well -- TODO confirm.
    """
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None

    ubatch_slices: UBatchSlices | None = None

    # If True, bypass the compiled model call, e.g. by using .forward() directly
    skip_compiled: bool = False

    # For torch.compile cold start times, we need to avoid hard-coding
    # any strings into the graph. Right now, the vllm.moe_forward
    # and vllm.moe_forward_shared custom operators hard-code strings into
    # the graph.
    #
    # The workaround is to store a list of the strings that each of those
    # custom ops needs in the ForwardContext (all_moe_layers)
    # as well as a counter (moe_layer_index).
    # The ForwardContext object is alive for the duration of the forward pass.
    # When the custom op needs a layer string, get the next string
    # from all_moe_layers and increment the counter.
    #
    # This assumes that the custom operators will always be executed in
    # order and that torch.compile will not try to reorder these
    # operations with respect to each other.
    #
    # TODO(https://github.com/vllm-project/vllm/issues/31985):
    # There are longer-term solutions, like unwrapping the moe custom operator,
    # that aren't ready yet.
    # We could also treat the string as a "symbolic input" to the graph but
    # the PyTorch-side bits for that aren't ready yet either.
    #
    # If this value is None (like in some tests), then we end up baking the string
    # into the graph. Otherwise, the moe custom ops will pop a string from this list.
    all_moe_layers: list[str] | None = None
    moe_layer_index: int = 0

    additional_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate that cudagraph_runtime_mode is a legal runtime mode."""
        assert self.cudagraph_runtime_mode.is_valid_runtime_mode(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )
187
188


189
# The forward context currently in effect, or None when none is installed.
# Read through get_forward_context() / is_forward_context_available() and
# swapped in and restored by override_forward_context().
_forward_context: ForwardContext | None = None
190
191
192


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: If no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return _forward_context


201
202
203
204
def is_forward_context_available() -> bool:
    """Report whether a forward context is currently installed."""
    if _forward_context is None:
        return False
    return True


205
def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
    additional_kwargs: dict[str, Any] | None = None,
    skip_compiled: bool = False,
) -> ForwardContext:
    """Assemble a ForwardContext from the config and per-pass inputs.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``.
    ``all_moe_layers`` is taken from ``static_all_moe_layers`` only when
    ``compilation_config.fast_moe_cold_start`` is truthy, otherwise it is
    None. A falsy ``slot_mapping`` or ``additional_kwargs`` (None or empty)
    is replaced by a fresh empty dict. All other arguments are forwarded to
    ForwardContext unchanged.

    This only constructs the object; it does not install it as the current
    forward context.
    """
    if vllm_config.compilation_config.fast_moe_cold_start:
        all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
    else:
        all_moe_layers = None

    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
        all_moe_layers=all_moe_layers,
        attn_metadata=attn_metadata,
        slot_mapping=slot_mapping or {},
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
        skip_compiled=skip_compiled,
        additional_kwargs=additional_kwargs or {},
    )
233
234
235


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """A context manager that overrides the current forward context.
    This is used to override the forward context for a specific
    forward pass.

    The previously active context (possibly None) is saved on entry and
    restored on exit, even if the body raises, so overrides can be nested.
    Passing None clears the context for the duration of the block.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context


250
@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
    skip_compiled: bool = False,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Steps performed:
      1. If batch-size tracking is enabled and ``attn_metadata`` is not None,
         record the start time.
      2. Build DPMetadata when ``data_parallel_size > 1``, ``is_moe_model`` is
         not False, and either ``attn_metadata`` or ``num_tokens`` is given.
         If ``num_tokens_across_dp`` is None it is obtained from
         ``coordinate_batch_across_dp`` (which requires ``num_tokens`` to be
         set and ``ubatch_slices`` to be None).
      3. When a cudagraph mode other than NONE is used and ``num_tokens`` is
         given, default ``batch_descriptor`` to
         ``BatchDescriptor(num_tokens=num_tokens)``.
      4. Collect platform-specific extras via
         ``current_platform.set_additional_forward_context``.
      5. Install the resulting ForwardContext (together with the IR op
         priority and torch-wrap settings from the config) for the duration
         of the ``with`` body.
      6. On exit, if tracking is active, synchronize the platform, record the
         elapsed time in milliseconds under ``num_tokens``, and log the
         per-batch-size medians once more than ``batchsize_logging_interval``
         has elapsed since the previous log.
    """
    global forward_start_time, last_logging_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: DPMetadata | None = None
    if (
        vllm_config.parallel_config.data_parallel_size > 1
        and vllm_config.parallel_config.is_moe_model is not False
        and (attn_metadata is not None or num_tokens is not None)
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
            )
            assert num_tokens_across_dp is not None
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    additional_kwargs = current_platform.set_additional_forward_context(
        attn_metadata=attn_metadata,
        vllm_config=vllm_config,
        dp_metadata=dp_metadata,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
        slot_mapping,
        additional_kwargs,
        skip_compiled,
    )

    try:
        with (
            override_forward_context(forward_context),
            vllm_config.kernel_config.ir_op_priority.set_priority(),
            vllm.ir.enable_torch_wrap(
                vllm_config.compilation_config.ir_enable_torch_wrap
            ),
        ):
            yield
    finally:
        if need_to_track_batchsize:
            batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    median_ms = torch.quantile(torch.tensor(times), q=0.5).item()
                    median_ms = round(median_ms, 2)
                    forward_stats.append((bs, len(times), median_ms))
                # Most frequently seen batch sizes first.
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(
                        (
                            "Batchsize forward time stats "
                            "(batchsize, count, median_time(ms)): %s"
                        ),
                        forward_stats,
                    )