# forward_context.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Optional, Union
9

10
import torch
11
import torch.distributed as dist
12

13
import vllm.envs as envs
14
from vllm.config import ParallelConfig, VllmConfig
15
16
from vllm.logger import init_logger

17
18
19
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

# Per-batch-size forward timing is opt-in: a negative
# VLLM_LOG_BATCHSIZE_INTERVAL leaves track_batchsize False.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0

# Timestamps taken from time.perf_counter() (seconds).
last_logging_time: float = 0
forward_start_time: float = 0

# batch size -> list of observed forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class DPMetadata:
    """Token-count bookkeeping for a data-parallel (DP) forward pass.

    Stores the maximum per-rank token count and the cumulative (prefix-sum)
    token counts across all DP ranks, in rank order. Both are expected to
    be CPU tensors (per the ``_cpu`` field suffix).
    """

    # torch.max over the per-rank token counts (a 0-dim tensor).
    max_tokens_across_dp_cpu: torch.Tensor
    # torch.cumsum over the per-rank token counts, in rank order.
    cu_tokens_across_dp_cpu: torch.Tensor

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.

        Each rank fills only its own slot of a zero vector; the all-reduce
        (default op: SUM) over the DP CPU group then yields every rank's
        count in every slot.
        """
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device="cpu",
                                         dtype=torch.int32)
        # Function-scope import; presumably avoids an import cycle with
        # vllm.distributed at module load time.
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
        return num_tokens_tensor

    @staticmethod
    def make(
            parallel_config: "ParallelConfig",
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DP metadata for the current forward pass.

        Args:
            parallel_config: parallel configuration; must have
                ``data_parallel_size > 1``.
            attn_metadata: v0 attention metadata (exposing
                ``num_prefill_tokens`` / ``num_decode_tokens``), or v1
                metadata / None, in which case ``num_tokens`` is used.
            num_tokens: local token count used when ``attn_metadata`` does
                not carry v0 token counts.
            num_tokens_across_dp: optional precomputed per-rank token
                counts; when None they are gathered with an all-reduce.

        Returns:
            A DPMetadata with the max and cumulative token counts.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp is None
                or num_tokens_across_dp[dp_rank] == batchsize)
        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)


@dataclass
class ForwardContext:
    """State for the current model forward pass, installed by
    ``set_forward_context`` and read via ``get_forward_context``."""

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Type AttentionMetadata for v0;
    # type dict[str, AttentionMetadata] for v1, mapping the layer_name of
    # each attention layer to its attention metadata.
    # Set dynamically for each forward pass.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional["DPMetadata"] = None
    # Passed through from set_forward_context(skip_cuda_graphs=...).
    skip_cuda_graphs: bool = False


# The ForwardContext of the innermost active set_forward_context block,
# or None outside of any forward pass.
_forward_context: Optional["ForwardContext"] = None


def get_forward_context() -> "ForwardContext":
    """Get the current forward context.

    Returns:
        The ForwardContext installed by the innermost active
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _record_forward_time(attn_metadata: Any,
                         num_tokens: Optional[int]) -> None:
    """Record the duration of the forward pass that just finished.

    Appends the elapsed time (ms since ``forward_start_time``) under the
    pass's batch size and, when more than ``batchsize_logging_interval``
    seconds have passed since the last report, logs
    (batchsize, count, median_time_ms) for every batch size seen more than
    once, sorted by count descending.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = (attn_metadata.num_prefill_tokens +
                     attn_metadata.num_decode_tokens)
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)


@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: attention metadata for this pass (may be None).
        vllm_config: config supplying the parallel and compilation settings.
        virtual_engine: virtual engine index stored on the context.
        num_tokens: local token count; used for DP metadata and batch-size
            tracking when ``attn_metadata`` has no v0 token counts.
        num_tokens_across_dp: optional precomputed per-DP-rank token counts.
        skip_cuda_graphs: stored on the context as-is.

    The previously active context is restored on exit, even if the forward
    pass or the batch-size bookkeeping raises.
    """
    # NOTE: forward_start_time is module-global, so nested tracked contexts
    # overwrite each other's start time.
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        skip_cuda_graphs=skip_cuda_graphs,
    )

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(attn_metadata, num_tokens)
        finally:
            # Always restore the outer context, even if synchronize() or the
            # logging above raises; otherwise a stale ForwardContext would
            # stay installed after the block exits.
            _forward_context = prev_context