"vllm/vscode:/vscode.git/clone" did not exist on "2bc4be4e32a42a439f7aad3752b96a20e7c34938"
forward_context.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
9

10
import torch
11
import torch.distributed as dist
12

13
import vllm.envs as envs
14
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
18

19
20
21
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

22
23
24
25
logger = init_logger(__name__)

# Per-batch-size forward-time tracking. The interval is compared against
# time.perf_counter() deltas (i.e. seconds); a negative value disables
# tracking altogether.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() timestamps of the last stats log line and of the start
# of the forward pass currently being timed.
last_logging_time: float = 0
forward_start_time: float = 0
# Batch size -> list of recorded forward times, in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class BatchDescriptor(NamedTuple):
    """
    Describes a (padded) batch for cudagraph dispatching.

    Keep the number of fields as small as possible while still describing
    the padded batch uniquely enough to select a cudagraph.
    """
    # Number of tokens in the (padded) batch.
    num_tokens: int
    # Whether this is a uniform decode batch. False can also be used for a
    # uniform decode batch in order to dispatch to the cudagraph that
    # supports non-uniform batches.
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """Return a copy of this descriptor with ``uniform_decode=False``."""
        return BatchDescriptor(num_tokens=self.num_tokens,
                               uniform_decode=False)


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    dp_size = len(num_tokens_across_dp_cpu)

    local_size = [-1] * dp_size
    for i in range(dp_size):
        dp_tokens = num_tokens_across_dp_cpu[i]
        local_size[i] = min(max_num_tokens,
                            dp_tokens - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size


67
68
@dataclass
class DPMetadata:
    """Token-count metadata shared across data-parallel (DP) ranks for one
    forward pass.

    Normally built with `DPMetadata.make`, which gathers the per-rank token
    counts (with an all-reduce when the caller does not supply them).
    """
    # Max over DP ranks of the per-rank token counts (result of `torch.max`
    # in `make`).
    max_tokens_across_dp_cpu: torch.Tensor
    # Cumulative sum over DP ranks of the per-rank token counts.
    cu_tokens_across_dp_cpu: torch.Tensor
    # Per-rank token counts of the current chunk. Only set while inside the
    # `chunked_sizes` context manager; None otherwise.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.

        Each rank writes its own count into slot `dp_rank` of an otherwise
        zero tensor; the tensors are then combined with `dist.all_reduce`
        (default op: sum) over the DP group.
        """
        from vllm.distributed.parallel_state import get_dp_group
        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from GPU to CPU introduces a GPU sync
        # point that could adversely affect the performance of vLLM with
        # asynchronous scheduling. This environment variable exists to
        # quickly disable this optimization if we run into that case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to syncronize DP padding between ranks.")
            device = "cpu"
            group = get_dp_group().cpu_group

        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device=device,
                                         dtype=torch.int32)
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it "aborts", meaning that no padding
        information is returned to the caller: it returns (False, None).

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens.

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False
        ]
        """

        # Row 0: original tokens per ubatch, row 1: padded tokens per ubatch,
        # row 2: whether this rank wants to microbatch. Each rank fills only
        # its own column, so the all-reduce assembles the full table.
        device = current_platform.device_type
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(tensor, group=get_dp_group().device_group)

        # Microbatch only if every rank opted in.
        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                         padded_max_num_tokens)
            return False, None
        return result, padded_num_tokens_tensor.cpu()

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DPMetadata for the current forward pass.

        Args:
            parallel_config: Parallel config; `data_parallel_size` must be
                greater than 1.
            attn_metadata: Attention metadata. If it has `num_prefill_tokens`
                (v0 attention backends), the local batch size is
                `num_prefill_tokens + num_decode_tokens`; otherwise
                `num_tokens` is used.
            num_tokens: Local number of tokens (v1 backends / no metadata).
            num_tokens_across_dp: Optional precomputed per-rank token counts.
                When None, they are gathered with an all-reduce.

        Returns:
            A DPMetadata holding the max and cumulative token counts across
            ranks. `local_sizes` is left as None; it is only populated by
            `chunked_sizes`.
        """

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        # Do not pass the per-rank tensor as `local_sizes`: that field holds
        # per-chunk list[int] sizes and is managed solely by `chunked_sizes`.
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
        to determine the chunk-wise split.

        `self.local_sizes` is only valid inside the context.

        Args:
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        # Convert once to Python ints, then difference adjacent cumulative
        # sums to recover each rank's token count.
        cu_sizes = self.cu_tokens_across_dp_cpu.tolist()
        num_tokens_across_dp_cpu = [
            cur - prev for prev, cur in zip([0] + cu_sizes, cu_sizes)
        ]
        self.local_sizes = _compute_chunked_local_num_tokens(
            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the per-rank chunk sizes set by `chunked_sizes` (None when
        called outside that context)."""
        return self.local_sizes

222

223
224
@dataclass
class ForwardContext:
    """Per-forward-pass state returned by `get_forward_context()`."""
    # Copied from vllm_config.compilation_config.static_forward_context.
    no_compile_layers: dict[str, Any]
    # Attention metadata, set dynamically for each forward pass:
    #   - AttentionMetadata for v0;
    #   - dict[str, AttentionMetadata] for v1, mapping the layer_name of each
    #     attention layer to its attention metadata;
    #   - list[dict[str, AttentionMetadata]] for DBO: a list of size two, one
    #     entry per microbatch.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # Set dynamically for each forward pass.
    dp_metadata: Optional[DPMetadata] = None
    # Cudagraph style chosen at runtime: FULL, PIECEWISE, or NONE. Defaults
    # to NONE, i.e. no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None
    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        mode = self.cudagraph_runtime_mode
        assert mode.valid_runtime_modes(), \
            f"Invalid cudagraph runtime mode: {mode}"
251
252
253
254
255
256


# Forward context of the forward pass in progress; None when no context is
# active (it is installed by `override_forward_context`).
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Return the forward context of the forward pass in progress.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    ctx = _forward_context
    assert ctx is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return ctx


264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Build a ForwardContext for a single forward pass.

    `no_compile_layers` is taken from
    `vllm_config.compilation_config.static_forward_context`; every other
    field is passed through unchanged.
    """
    static_layers = vllm_config.compilation_config.static_forward_context
    return ForwardContext(
        no_compile_layers=static_layers,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )


@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """Temporarily install `forward_context` as the current forward context.

    The previously active context (possibly None) is restored on exit,
    including when the body raises.
    """
    global _forward_context
    saved_context, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved_context


297
def _record_batchsize_forward_time(batchsize: Optional[int]) -> None:
    """Record the duration of the forward pass that just finished and
    periodically log per-batch-size median forward times.

    The duration is measured from the module-level `forward_start_time` to
    now, after a device synchronize. A log line is emitted at most once per
    `batchsize_logging_interval` seconds of `time.perf_counter()` time.
    """
    global last_logging_time

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequent batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Context manager that installs the forward context (attention metadata,
    DP metadata, cudagraph mode, ...) for one model forward pass.

    Common per-forward-pass logic lives here:
      * when data parallelism is enabled (`data_parallel_size > 1`) and
        either `attn_metadata` or `num_tokens` is given, DP metadata is
        built with `DPMetadata.make`;
      * when batch-size tracking is enabled and `attn_metadata` is not None,
        the forward time is recorded on exit (also when the body raises).
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = (attn_metadata.num_prefill_tokens +
                             attn_metadata.num_decode_tokens)
            else:
                # for v1 attention backends
                batchsize = num_tokens
            _record_batchsize_forward_time(batchsize)