forward_context.py 11.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import time
6
from collections import defaultdict
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
10

11
import torch
12
import torch.distributed as dist
13

14
import vllm.envs as envs
15
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
16
from vllm.logger import init_logger
17
from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context
18

19
20
21
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

# --- Batch-size timing statistics (maintained by `set_forward_context`) ---
# How often (in time.perf_counter seconds) the stats are logged; a negative
# value disables tracking entirely.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value at which the currently timed forward pass began.
forward_start_time: float = 0
# batch size -> list of forward durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
    """
    Key used to dispatch a padded batch to a cudagraph.

    Keep the number of fields as small as possible while still describing
    the padded batch properly and uniquely.
    """
    num_tokens: int
    # A uniform decode batch may still carry False here so that it is
    # dispatched to the cudagraph that supports non-uniform batches.
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """Return a copy of this descriptor with ``uniform_decode=False``."""
        return self._replace(uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    dp_size = len(num_tokens_across_dp_cpu)

    local_size = [-1] * dp_size
    for i in range(dp_size):
        dp_tokens = num_tokens_across_dp_cpu[i]
        local_size[i] = min(max_num_tokens,
                            dp_tokens - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size


67
68
@dataclass
class DPMetadata:
    # Largest per-rank token count (torch.max over all DP ranks, see `make`).
    max_tokens_across_dp_cpu: torch.Tensor
    # Inclusive prefix sum of the per-rank token counts (see `make`).
    cu_tokens_across_dp_cpu: torch.Tensor
    # Per-rank token counts of the current chunk; only populated while inside
    # `chunked_sizes`, None otherwise.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Share this rank's token count with every DP rank.

        Each rank writes its own count into an otherwise-zero int32 CPU
        tensor of length ``dp_size``; summing those tensors over the DP CPU
        group with ``all_reduce`` leaves slot ``r`` holding rank ``r``'s count.
        """
        counts = torch.zeros(dp_size, dtype=torch.int32, device="cpu")
        counts[dp_rank] = num_tokens
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(counts, group=get_dp_group().cpu_group)
        return counts

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """
        Build the DP metadata for one forward pass.

        The local batch size comes from the v0 attention metadata fields
        (``num_prefill_tokens + num_decode_tokens``) when present, and from
        ``num_tokens`` otherwise (v1 backends or no metadata). When
        ``num_tokens_across_dp`` is not supplied it is gathered with an
        all_reduce; when supplied, its entry for this rank must equal the
        local batch size.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        has_v0_fields = attn_metadata is not None and hasattr(
            attn_metadata, "num_prefill_tokens")
        if has_v0_fields:
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            batchsize = num_tokens

        if num_tokens_across_dp is not None:
            # A caller-provided tensor must agree with this rank's own count.
            assert num_tokens_across_dp[dp_rank] == batchsize
        else:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)

        return DPMetadata(
            max_tokens_across_dp_cpu=torch.max(num_tokens_across_dp),
            cu_tokens_across_dp_cpu=torch.cumsum(num_tokens_across_dp, dim=0),
        )

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Temporarily expose the per-rank token counts of one chunk.

        During chunked forward execution every rank splits its tokens into
        chunks of at most ``max_chunk_size_per_rank``. For chunk
        ``chunk_idx`` this sets ``self.local_sizes`` to the number of tokens
        each rank processes, with ranks that have already run out of tokens
        reporting 1 so all ranks keep stepping together. Per-rank totals are
        recovered from ``cu_tokens_across_dp_cpu``.

        ``self.local_sizes`` is reset to None when the context exits.

        Args:
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.

        Yields:
            The per-rank chunk sizes (same object as ``self.local_sizes``).
        """
        running_totals = self.cu_tokens_across_dp_cpu.tolist()
        previous_totals = [0] + running_totals[:-1]
        per_rank_tokens = [
            total - before
            for before, total in zip(previous_totals, running_totals)
        ]
        self.local_sizes = _compute_chunked_local_num_tokens(
            per_rank_tokens, max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the current chunk's per-rank sizes (None outside a chunk)."""
        return self.local_sizes
@dataclass
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Attention metadata, set anew for each forward pass:
    #   - v0: a single AttentionMetadata
    #   - v1: dict mapping each attention layer's name to its
    #     AttentionMetadata
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # Set anew for each forward pass.
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int
    # Set anew for each forward pass; defaults to None.
    dp_metadata: Optional[DPMetadata] = None
    # cudagraph style chosen at runtime: FULL, PIECEWISE or NONE.
    # Defaults to NONE, i.e. no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    def __post_init__(self):
        # Only these three modes are valid at runtime.
        runtime_modes = (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE,
                         CUDAGraphMode.FULL)
        assert self.cudagraph_runtime_mode in runtime_modes, \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    When ``envs.VLLM_ENABLE_TBO`` is set the context is obtained from
    ``get_tbo_forward_context()``; otherwise the module-level context
    installed by ``set_forward_context`` is returned.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    if envs.VLLM_ENABLE_TBO:
        forward_context = get_tbo_forward_context()
    else:
        forward_context = _forward_context
    # One shared check for both sources, with the original message.
    assert forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return forward_context
@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    On entry this builds a ``ForwardContext`` and installs it as the
    module-level context. It also installs it via
    ``set_tbo_forward_context`` when ``envs.VLLM_ENABLE_TBO`` is set.
    ``DPMetadata`` is built only when data parallelism is enabled and either
    ``attn_metadata`` or ``num_tokens`` is given.

    On exit it optionally records batch-size timing statistics. It then
    restores the previous context. Restoration happens even if the statistics
    code raises, so a stale context is never left installed.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    )
    if envs.VLLM_ENABLE_TBO:
        set_tbo_forward_context(_forward_context)

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(attn_metadata, num_tokens)
        finally:
            # Always put the previous context back, even if synchronization
            # or stats logging above failed.
            _forward_context = prev_context
            if envs.VLLM_ENABLE_TBO:
                set_tbo_forward_context(_forward_context)


def _record_forward_time(attn_metadata: Any,
                         num_tokens: Optional[int]) -> None:
    """Record the duration of the forward pass that just finished.

    The duration is measured from the module-level ``forward_start_time``,
    which ``set_forward_context`` stamps on entry. Per-batch-size median
    times are logged when more than ``batchsize_logging_interval`` seconds
    have passed since the last log.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)

_profiling: bool = False


class _ProfilingFlagGuard:
    """Handle returned by `set_profilling`.

    Ignoring it is fine. Used in a ``with`` statement, it restores the
    profiling flag's previous value when the block exits.
    """

    def __init__(self, previous: bool) -> None:
        self._previous = previous

    def __enter__(self) -> "_ProfilingFlagGuard":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        global _profiling
        _profiling = self._previous
        return False  # never suppress exceptions


def set_profilling(profiling: bool) -> _ProfilingFlagGuard:
    """Set the module-level profiling flag.

    The flag is changed immediately on call, so it can be used as a plain
    function. It can also be used as ``with set_profilling(value): ...``, in
    which case the previous value is restored when the block exits.
    """
    # A @contextmanager-decorated function must yield. This one set the flag
    # eagerly instead, which made `with` raise TypeError. A guard object
    # keeps the eager behavior and makes `with` usable.
    global _profiling
    previous = _profiling
    _profiling = profiling
    return _ProfilingFlagGuard(previous)


def get_profilling() -> bool:
    """Return the current value of the module-level profiling flag."""
    return _profiling