forward_context.py 9.56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Optional, Union
9

10
import torch
11
import torch.distributed as dist
12

13
import vllm.envs as envs
14
from vllm.config import ParallelConfig, VllmConfig
15
16
from vllm.logger import init_logger

17
18
19
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

20
21
22
23
logger = init_logger(__name__)

# Batch-size timing instrumentation consumed by `set_forward_context`.
# Interval between stats log lines, in seconds (it is compared against
# time.perf_counter() deltas); a negative value disables tracking entirely.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the last stats line was logged.
last_logging_time: float = 0
# time.perf_counter() value at the start of the most recent tracked pass.
forward_start_time: float = 0
# batch size -> list of measured forward times, in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
27
28


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    dp_size = len(num_tokens_across_dp_cpu)

    local_size = [-1] * dp_size
    for i in range(dp_size):
        dp_tokens = num_tokens_across_dp_cpu[i]
        local_size[i] = min(max_num_tokens,
                            dp_tokens - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size


44
45
@dataclass
class DPMetadata:
    """Token-count bookkeeping shared across data-parallel (DP) ranks.

    Attributes:
        max_tokens_across_dp_cpu: ``torch.max`` of the per-rank token counts
            (expected on CPU, per the field name).
        cu_tokens_across_dp_cpu: Inclusive cumulative sum of the per-rank
            token counts, indexed by DP rank.
        local_sizes: Per-rank token counts for the chunk currently being
            processed; only populated inside ``chunked_sizes``.
    """

    max_tokens_across_dp_cpu: torch.Tensor
    cu_tokens_across_dp_cpu: torch.Tensor
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """Collect every DP rank's token count into one CPU tensor.

        This rank writes its own count into slot ``dp_rank`` of a zeroed
        int32 vector of length ``dp_size``. A (default, sum) all-reduce over
        the DP group's CPU process group then fills in the other ranks'
        slots. Because ``all_reduce`` is a collective, every DP rank has to
        make this call.

        Args:
            num_tokens: This rank's token count.
            dp_size: Number of DP ranks.
            dp_rank: Index of this rank within the DP group.

        Returns:
            An int32 CPU tensor with ``dp_size`` entries.
        """
        slots = [0] * dp_size
        slots[dp_rank] = num_tokens
        gathered = torch.tensor(slots, device="cpu", dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(gathered, group=get_dp_group().cpu_group)
        return gathered

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DP metadata for one forward pass.

        The local batch size comes from v0-style attention metadata
        (``num_prefill_tokens + num_decode_tokens``) when it exposes those
        fields. Otherwise it is ``num_tokens`` (v1 backends, or no attention
        metadata).

        Args:
            parallel_config: Supplies ``data_parallel_size`` (must be > 1)
                and ``data_parallel_rank``.
            attn_metadata: Attention metadata for this pass, or None.
            num_tokens: Token count used when ``attn_metadata`` does not
                expose v0-style counts.
            num_tokens_across_dp: Optional precomputed per-rank token counts.
                When omitted, they are gathered via ``num_tokens_across_dp``
                (a collective across DP ranks). When given, this rank's entry
                must equal the local batch size.

        Raises:
            AssertionError: If the DP size is not greater than 1, or if a
                supplied ``num_tokens_across_dp`` disagrees with the local
                batch size.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        has_v0_counts = (attn_metadata is not None
                         and hasattr(attn_metadata, "num_prefill_tokens"))
        if has_v0_counts:
            # v0 attention backends report prefill and decode separately.
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            # v1 attention backends, or no attention metadata at all.
            batchsize = num_tokens

        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        else:
            # Caller-supplied counts must agree with this rank's batch size.
            assert num_tokens_across_dp[dp_rank] == batchsize

        return DPMetadata(
            max_tokens_across_dp_cpu=torch.max(num_tokens_across_dp),
            cu_tokens_across_dp_cpu=torch.cumsum(num_tokens_across_dp,
                                                 dim=0),
        )

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
        """Temporarily expose the per-rank token counts for one chunk.

        During chunked forward execution, each rank's tokens are split into
        chunks of at most ``max_chunk_size_per_rank`` tokens. For the chunk
        ``chunk_idx``, this sets ``self.local_sizes`` to how many tokens each
        rank processes. Ranks that have already run out of tokens report 1,
        so all DP ranks stay in lockstep even when their token counts are
        uneven.

        Per-rank totals are recovered from ``cu_tokens_across_dp_cpu`` and
        split by ``_compute_chunked_local_num_tokens``. ``self.local_sizes``
        is reset to None when the context exits.

        Args:
            max_chunk_size_per_rank: Maximum number of tokens any rank may
                process in this chunk.
            chunk_idx: Zero-based index of the chunk.

        Yields:
            The list assigned to ``self.local_sizes``.
        """
        # Undo the cumulative sum to get each rank's own token count.
        per_rank_tokens = []
        previous = None
        for running_total in self.cu_tokens_across_dp_cpu:
            if previous is None:
                own_count = running_total
            else:
                own_count = running_total - previous
            per_rank_tokens.append(own_count.item())
            previous = running_total

        self.local_sizes = _compute_chunked_local_num_tokens(
            per_rank_tokens, max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            # The sizes are only meaningful inside the context.
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the sizes set by ``chunked_sizes``.

        Returns None when called outside that context.
        """
        return self.local_sizes

139

140
141
@dataclass
class ForwardContext:
    """Per-forward-pass state made available to model code.

    Instances are created by ``set_forward_context`` and retrieved with
    ``get_forward_context``.
    """

    # Copy of vllm_config.compilation_config.static_forward_context.
    no_compile_layers: dict[str, Any]

    # Set dynamically for each forward pass. The type depends on the engine:
    #   * v0: a single AttentionMetadata object;
    #   * v1: dict[str, AttentionMetadata], mapping the layer_name of each
    #     attention layer to that layer's attention metadata.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]

    # Set dynamically for each forward pass.
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int

    # Data-parallel token bookkeeping. Set dynamically for each forward pass;
    # None when no DP metadata was built.
    dp_metadata: Optional[DPMetadata] = None

    # Flag carried for consumers of the context; this file only stores it.
    skip_cuda_graphs: bool = False
156
157
158
159
160
161


# The context installed by the innermost active `set_forward_context` block,
# or None outside of any such block.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Return the forward context that is currently installed.

    Raises:
        AssertionError: If no ``set_forward_context`` block is active.
    """
    ctx = _forward_context
    assert ctx is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return ctx


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
):
    """Install a ForwardContext for the duration of a ``with`` block.

    Model code inside the block reads it via ``get_forward_context()``. This
    is also the place to inject logic common to every model forward pass.
    The previously active context is always restored on exit, even if the
    batch-size bookkeeping below fails, so calls may nest.

    When ``data_parallel_size > 1`` and either ``attn_metadata`` or
    ``num_tokens`` is given, a DPMetadata is attached to the context. If
    ``num_tokens_across_dp`` is None, building it all-reduces token counts
    across DP ranks (a collective operation).

    When batch-size tracking is enabled (``track_batchsize``) and
    ``attn_metadata`` is not None, the wall time of the block is recorded
    per batch size and periodically logged by ``_record_forward_time``.

    Args:
        attn_metadata: v0 AttentionMetadata, a v1 per-layer dict of
            attention metadata, or None.
        vllm_config: Supplies the parallel and compilation configs.
        virtual_engine: Stored on the context as-is.
        num_tokens: Token count for this pass. It is used as the batch size
            when ``attn_metadata`` lacks v0-style prefill/decode counts.
        num_tokens_across_dp: Optional precomputed per-DP-rank token counts.
        skip_cuda_graphs: Stored on the context as-is.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    # Keep the start time local so that a nested tracked context cannot
    # overwrite this pass's measurement.
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        # Still publish it at module level for backward compatibility, in
        # case other code reads `forward_start_time`.
        forward_start_time = start_time

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        skip_cuda_graphs=skip_cuda_graphs,
    )

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(attn_metadata, num_tokens, start_time)
        finally:
            # Restore unconditionally. If synchronization or stats logging
            # raised, this pass's context must not leak past the block.
            _forward_context = prev_context


def _record_forward_time(attn_metadata: Any, num_tokens: Optional[int],
                         start_time: float) -> None:
    """Record one forward pass's wall time and periodically log stats.

    Stats are the median time per batch size, sorted by how often each batch
    size was seen.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = (attn_metadata.num_prefill_tokens +
                     attn_metadata.num_decode_tokens)
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return

    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)