# forward_context.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, NamedTuple, Union
9

10
11
import torch

12
import vllm.envs as envs
13
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
14
from vllm.logger import init_logger
15
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
16
from vllm.v1.worker.ubatch_utils import UBatchSlices
17

18
19
20
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

21
22
23
24
logger = init_logger(__name__)

# Batch-size forward-time tracking is enabled when the configured logging
# interval is non-negative.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value at which the current tracked forward pass began.
forward_start_time: float = 0
# Batch size -> list of observed forward durations, in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the number of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """

    # Number of tokens in the (padded) batch.
    num_tokens: int
    # Whether the batch is a uniform decode batch. False can also be used for
    # a uniform decode batch to dispatch to the cudagraph supporting
    # non-uniform batches.
    uniform_decode: bool = False
    # Whether this batch has active LoRA adapters.
    has_lora: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of this batch descriptor.

        Only ``uniform_decode`` is cleared; ``num_tokens`` and ``has_lora`` are
        carried over unchanged.
        """
        return BatchDescriptor(
            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
        )


58
59
60
61
62
63
def _compute_sp_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
) -> list[int]:
    sp_tokens = (
        num_tokens_across_dp_cpu + sequence_parallel_size - 1
    ) // sequence_parallel_size
64
65
66
67
68

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()


def _compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: torch.Tensor,
    sequence_parallel_size: int,
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]:
    """Per-(DP, SP)-rank token counts for chunk ``chunk_idx``.

    Each rank's tokens are split into consecutive chunks of at most
    ``max_num_tokens``. The entry for a rank is the number of its tokens that
    fall into chunk ``chunk_idx``, clamped to a minimum of 1.

    Args:
        num_tokens_across_dp_cpu: Per-DP-rank token counts.
        sequence_parallel_size: Number of sequence-parallel ranks per DP rank.
        max_num_tokens: Maximum number of tokens per rank in one chunk.
        chunk_idx: Index of the chunk to compute sizes for.

    Returns:
        One token count per (DP, SP) rank, each >= 1.
    """
    # Take into account sharding if MoE activation is sequence parallel.
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size)
    chunk_start = max_num_tokens * chunk_idx
    sizes = [min(max_num_tokens, n - chunk_start) for n in sp_tokens]
    # A rank that has already consumed all of its tokens still reports 1 so
    # that every rank keeps stepping through the chunks in lockstep.
    return [size if size > 0 else 1 for size in sizes]


@dataclass
class DPMetadata:
    """Token counts across data-parallel (DP) ranks for one forward pass.

    Instances are normally created via :meth:`make`. ``local_sizes`` is
    scratch state that is populated only inside the :meth:`chunked_sizes` /
    :meth:`sp_local_sizes` context managers.
    """

    # Result of torch.max over num_tokens_across_dp_cpu.
    max_tokens_across_dp_cpu: torch.Tensor
    # Token count of every DP rank, indexed by DP rank.
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Build DP metadata from already-coordinated per-rank token counts.

        Args:
            parallel_config: Parallel config; ``data_parallel_size`` must be
                greater than 1.
            num_tokens: Number of tokens on this DP rank.
            num_tokens_across_dp_cpu: Per-DP-rank token counts. Must not be
                None, and the entry for this rank must equal ``num_tokens``.

        Raises:
            AssertionError: If any of the preconditions above is violated.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        dp_rank = parallel_config.data_parallel_rank

        # The caller must already have coordinated counts across DP ranks;
        # this rank's entry has to agree with the local token count.
        assert num_tokens_across_dp_cpu[dp_rank] == num_tokens, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {num_tokens}"
        )
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(
        self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int
    ):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context. On exit it is
        restored to the value it had on entry (None unless another sizing
        context is active), so nested contexts do not clobber each other.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.

        Yields:
            The per-rank sizes, also stored in ``self.local_sizes``.
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu,
            sequence_parallel_size,
            max_chunk_size_per_rank,
            chunk_idx,
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = prev_local_sizes

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as
        self.chunked_sizes but without any chunking.

        Yields:
            The per-(DP, SP)-rank sizes, also stored in ``self.local_sizes``
            until the context exits (the previous value is then restored).
        """
        prev_local_sizes = self.local_sizes
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = prev_local_sizes

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the sizes set by the active sizing context.

        Raises:
            AssertionError: If called outside ``chunked_sizes`` /
                ``sp_local_sizes``.
        """
        assert self.local_sizes is not None
        return self.local_sizes

    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        # Ceil-divide each DP rank's tokens across its SP ranks.
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

@dataclass
class ForwardContext:
    """State for one forward pass, made current via ``override_forward_context``
    and read back with ``get_forward_context``."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Type AttentionMetadata for v0,
    # Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    # attention layer to its attention metadata.
    # Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
    # for each microbatch.
    # Set dynamically for each forward pass.
    attn_metadata: Union[
        "AttentionMetadata",
        dict[str, "AttentionMetadata"],
        list[dict[str, "AttentionMetadata"]],
    ]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None
    # Microbatch slices (presumably for DBO microbatching; None otherwise --
    # confirm against callers).
    ubatch_slices: UBatchSlices | None = None

    def __post_init__(self):
        # Reject cudagraph modes that are not valid at runtime.
        assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )


# The currently installed ForwardContext. It is swapped in by
# override_forward_context and restored when that context exits.
_forward_context: ForwardContext | None = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: If no forward context is currently set, e.g. when
            called outside ``set_forward_context``.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return _forward_context


def create_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    dp_metadata: DPMetadata | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
) -> ForwardContext:
    """Build a ForwardContext without installing it.

    ``no_compile_layers`` is taken from
    ``vllm_config.compilation_config.static_forward_context``. All other
    fields are passed through unchanged. Use ``override_forward_context`` to
    make the returned context current.
    """
    return ForwardContext(
        no_compile_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )


@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """A context manager that overrides the current forward context.

    This is used to override the forward context for a specific forward pass.
    The previously installed context (possibly None) is restored on exit,
    even if the body raises.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context


def _build_dp_metadata(
    vllm_config: VllmConfig,
    attn_metadata: Any,
    num_tokens: int | None,
    num_tokens_across_dp: torch.Tensor | None,
    ubatch_slices: UBatchSlices | None,
) -> DPMetadata | None:
    """Return DPMetadata for this forward pass, or None when it is not needed.

    Metadata is built only when ``data_parallel_size > 1`` and either
    ``attn_metadata`` or ``num_tokens`` is given.
    """
    if vllm_config.parallel_config.data_parallel_size <= 1:
        return None
    if attn_metadata is None and num_tokens is None:
        return None

    # If num_tokens_across_dp hasn't already been initialized, then
    # initialize it here. Both DP padding and Microbatching will be
    # disabled.
    if num_tokens_across_dp is None:
        assert ubatch_slices is None
        assert num_tokens is not None
        _, num_tokens_across_dp = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            parallel_config=vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=False,
        )
        assert num_tokens_across_dp is not None
    return DPMetadata.make(
        vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
    )


def _record_forward_time(attn_metadata: Any, num_tokens: int | None) -> None:
    """Record the elapsed forward time and periodically log per-batchsize stats.

    Measures from the module-level ``forward_start_time``. Stats are logged
    at most once per ``batchsize_logging_interval`` (perf_counter seconds).
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = (
            attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
        )
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform

    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    # NOTE(review): nothing in the visible code ever evicts these samples, so
    # the per-batchsize lists grow with every tracked forward pass.
    batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(
            (
                "Batchsize forward time stats "
                "(batchsize, count, median_time(ms)): %s"
            ),
            forward_stats,
        )


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Specifically:
    - DP metadata is built when data_parallel_size > 1 and either
      ``attn_metadata`` or ``num_tokens`` is given.
    - When a cudagraph mode is active and ``num_tokens`` is known, a default
      ``BatchDescriptor(num_tokens=num_tokens)`` is used if none was given.
    - When batch-size tracking is enabled and ``attn_metadata`` is not None,
      the platform is synchronized on exit and the forward time is recorded
      (and periodically logged).
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata = _build_dp_metadata(
        vllm_config, attn_metadata, num_tokens, num_tokens_across_dp, ubatch_slices
    )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        virtual_engine,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            _record_forward_time(attn_metadata, num_tokens)