# forward_context.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import time
6
from collections import defaultdict
7
from contextlib import contextmanager
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any, Optional, Union
10

11
import torch
12
import torch.distributed as dist
13

14
import vllm.envs as envs
15
from vllm.config import ParallelConfig, VllmConfig
16
from vllm.logger import init_logger
17
from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context
18

19
20
21
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

22
23
24
25
logger = init_logger(__name__)

# Forward-time tracking per batch size is enabled when
# VLLM_LOG_BATCHSIZE_INTERVAL >= 0.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() value (seconds) when stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value (seconds) at the start of the most recently
# tracked forward pass.
forward_start_time: float = 0
# Minimum number of seconds between two stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# batch size -> list of recorded forward durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
29

30
31
@dataclass
class DPMetadata:
    """Per-forward-pass token bookkeeping for data parallelism (DP).

    Both fields are CPU tensors derived from the number of tokens every DP
    rank processes in the current forward pass.
    """

    # Maximum per-rank token count across DP ranks (result of ``torch.max``).
    max_tokens_across_dp_cpu: torch.Tensor
    # Inclusive prefix sum (``torch.cumsum``) of the per-rank token counts.
    cu_tokens_across_dp_cpu: torch.Tensor

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """Gather ``num_tokens`` from every DP rank.

        Each rank writes its own count at index ``dp_rank`` of an otherwise
        zero int32 CPU tensor of length ``dp_size``. An all-reduce (default
        op: sum) over the DP group's CPU process group then fills in every
        rank's count.

        Returns:
            An int32 CPU tensor of size ``dp_size``.
        """
        per_rank_counts = [0] * dp_size
        per_rank_counts[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(per_rank_counts,
                                         device="cpu",
                                         dtype=torch.int32)
        # Imported lazily, as in the original; presumably to avoid an import
        # cycle with vllm.distributed — confirm before hoisting.
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
        return num_tokens_tensor

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DP metadata for the current forward pass.

        Args:
            parallel_config: parallel config; ``data_parallel_size`` must be
                greater than 1.
            attn_metadata: v0 attention metadata (exposing
                ``num_prefill_tokens`` / ``num_decode_tokens``), v1 attention
                metadata, or None.
            num_tokens: local token count, used whenever ``attn_metadata``
                does not expose v0-style token counts.
            num_tokens_across_dp: optional precomputed per-rank token counts.
                When given, no collective is issued and its entry for this
                rank must equal the local token count.

        Returns:
            A ``DPMetadata`` holding the max and cumulative token counts.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        if num_tokens_across_dp is None:
            # Not supplied by the caller: compute it with an all_reduce.
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        else:
            # Supplied by the caller: it must agree with the local count.
            assert num_tokens_across_dp[dp_rank] == batchsize
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)

82

83
84
@dataclass
class ForwardContext:
    """State shared by all layers during one model forward pass.

    Instances are built by ``set_forward_context`` and read back with
    ``get_forward_context``.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Type AttentionMetadata for v0; type dict[str, AttentionMetadata] for
    # v1, mapping the layer_name of each attention layer to its attention
    # metadata. Set dynamically for each forward pass.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass; None unless data parallel
    dp_metadata: Optional[DPMetadata] = None
    # Passed through from set_forward_context(skip_cuda_graphs=...).
    skip_cuda_graphs: bool = False
99
100
101
102
103
104


# Current ForwardContext. ``set_forward_context`` installs and restores it on
# every call; ``get_forward_context`` reads it only when envs.VLLM_ENABLE_TBO
# is false.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    When two-batch overlap is enabled (``envs.VLLM_ENABLE_TBO``), the context
    comes from ``get_tbo_forward_context()``. Otherwise it is the module-level
    ``_forward_context``.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    if envs.VLLM_ENABLE_TBO:
        forward_context = get_tbo_forward_context()
    else:
        forward_context = _forward_context
    assert forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return forward_context


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
):
    """Install a ForwardContext for the duration of one model forward pass.

    On entry, builds a ``ForwardContext`` from ``attn_metadata`` and
    ``vllm_config``. A ``DPMetadata`` is attached when
    ``data_parallel_size > 1`` and either ``attn_metadata`` or ``num_tokens``
    is given. The context becomes the current one, and is also published via
    ``set_tbo_forward_context`` when ``envs.VLLM_ENABLE_TBO`` is set. On exit,
    the previously active context is restored, so nesting is supported.

    When batch-size tracking is enabled (``track_batchsize``) and
    ``attn_metadata`` is not None, the forward duration is recorded per batch
    size and summary stats are logged periodically.

    Args:
        attn_metadata: attention metadata for this pass; may be None.
        vllm_config: engine config; supplies the parallel config and the
            static forward context.
        virtual_engine: virtual engine index stored on the context.
        num_tokens: local token count, used for DP metadata and (for v1
            metadata) as the tracked batch size.
        num_tokens_across_dp: optional precomputed per-DP-rank token counts.
        skip_cuda_graphs: flag stored on the context.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        # Keep the start time in a local so a nested set_forward_context
        # cannot overwrite it. The global is still updated for any reader.
        start_time = time.perf_counter()
        forward_start_time = start_time

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        skip_cuda_graphs=skip_cuda_graphs,
    )
    if envs.VLLM_ENABLE_TBO:
        set_tbo_forward_context(_forward_context)

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_batchsize_forward_time(
                    _tracked_batchsize(attn_metadata, num_tokens), start_time)
        finally:
            # Always restore the previous context, even if the stats code
            # above raised; otherwise this context would leak past its scope.
            _forward_context = prev_context
            if envs.VLLM_ENABLE_TBO:
                set_tbo_forward_context(_forward_context)


def _tracked_batchsize(attn_metadata: Any, num_tokens: Optional[int]) -> Any:
    """Return the batch size used as the key for forward-time stats."""
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends
    return num_tokens


def _record_batchsize_forward_time(batchsize: Any, start_time: float) -> None:
    """Record one forward duration for ``batchsize`` and log stats if due.

    Durations are stored in milliseconds. Stats are logged at most once per
    ``batchsize_logging_interval`` seconds.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return

    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = round(torch.quantile(torch.tensor(times), q=0.5).item(), 2)
        forward_stats.append((bs, len(times), median))
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)
199
200
201
202
203
204
205
206
207
208
209
210
211


# Module-level profiling flag, written by `set_profilling` and read by
# `get_profilling`.
_profiling: bool = False


def set_profilling(profiling: bool):
    """Set the module-level profiling flag.

    The flag is updated immediately, so a plain call such as
    ``set_profilling(True)`` works as a setter. The return value is also a
    context manager: ``with set_profilling(True): ...`` restores the previous
    value when the block exits, including on exceptions.
    """
    global _profiling
    previous = _profiling
    _profiling = profiling

    @contextmanager
    def _restore_on_exit():
        global _profiling
        try:
            yield
        finally:
            _profiling = previous

    return _restore_on_exit()


def get_profilling() -> bool:
    """Return the current value of the module-level profiling flag."""
    return _profiling