# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Optional, Union
8

9
import torch
10
import torch.distributed as dist
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
15
from vllm.logger import init_logger

16
17
18
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

19
20
21
22
logger = init_logger(__name__)

# Batch-size / forward-time tracking state shared by every forward pass.
# Tracking is enabled when the configured logging interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# `time.perf_counter()` reading taken when the stats were last logged.
last_logging_time: float = 0
# `time.perf_counter()` reading taken when the current forward pass began.
forward_start_time: float = 0
# Minimum gap between two stats log lines, compared against
# `time.perf_counter()` differences.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batch size -> list of measured forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
26
27


28
29
@dataclass
class DPMetadata:
    """Token-count statistics aggregated over all data-parallel (DP) ranks.

    Both fields are derived in `make` from a 1-D tensor that holds, for each
    DP rank, the number of tokens processed by that rank in this forward pass.
    """
    # Largest per-rank token count (result of `torch.max`).
    max_tokens_across_dp_cpu: torch.Tensor
    # Inclusive cumulative sum of the per-rank token counts, indexed by rank.
    cu_tokens_across_dp_cpu: torch.Tensor

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.

        Args:
            num_tokens: Token count of the calling rank.
            dp_size: Number of DP ranks.
            dp_rank: Index of the calling rank.
        """
        # Each rank fills only its own slot and leaves the others at zero, so
        # the all-reduce below yields every rank's count in a single tensor.
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device="cpu",
                                         dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
        return num_tokens_tensor

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DP metadata for the current forward pass.

        Args:
            parallel_config: Supplies `data_parallel_size` (must be > 1) and
                `data_parallel_rank`.
            attn_metadata: Attention metadata, or None. When it exposes
                `num_prefill_tokens`, the local batch size is taken from it.
            num_tokens: Local batch size used when `attn_metadata` does not
                expose `num_prefill_tokens`.
            num_tokens_across_dp: Optional precomputed per-rank token counts.
                When omitted, the counts are gathered with an all-reduce.

        Returns:
            A `DPMetadata` holding the max and the cumulative sum of the
            per-rank token counts.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        # `hasattr` is False for None, so no separate None check is needed.
        if hasattr(attn_metadata, "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp is None
                or num_tokens_across_dp[dp_rank] == batchsize)
        if num_tokens_across_dp is None:
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)

        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)

80

81
82
@dataclass
class ForwardContext:
    """State describing the forward pass that is currently running."""
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Type AttentionMetadata for v0,
    # Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    # attention layer to its attention metadata.
    # Set dynamically for each forward pass.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
96
97
98
99
100
101


# The forward context that is currently active, or None outside of any
# `set_forward_context` block.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Raises:
        AssertionError: If no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: Optional[int] = None,
                        num_tokens_across_dp: Optional[torch.Tensor] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: Attention metadata for this forward pass, or None.
        vllm_config: Supplies `parallel_config` (for data-parallel metadata)
            and `compilation_config.static_forward_context`.
        virtual_engine: Stored on the forward context as-is.
        num_tokens: Batch size used when `attn_metadata` does not expose
            `num_prefill_tokens`.
        num_tokens_across_dp: Optional precomputed per-rank token counts,
            forwarded to `DPMetadata.make`.

    The previously active context is restored on exit, so nested use is
    supported.
    """
    global forward_start_time, _forward_context

    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    try:
        yield
    finally:
        # Restore the previous context even if the device sync or the stats
        # bookkeeping raises; otherwise a stale context would stay installed.
        try:
            if need_to_track_batchsize:
                _record_batchsize_forward_time(attn_metadata, num_tokens)
        finally:
            _forward_context = prev_context


def _record_batchsize_forward_time(attn_metadata: Any,
                                   num_tokens: Optional[int]) -> None:
    """Record the elapsed forward time and periodically log median stats."""
    global last_logging_time

    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now

    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)