forward_context.py 6.38 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Optional, Union
8

9
import torch
10
import torch.distributed as dist
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
15
from vllm.logger import init_logger

16
17
18
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

19
20
21
22
logger = init_logger(__name__)

# Batch-size timing instrumentation, controlled by
# VLLM_LOG_BATCHSIZE_INTERVAL; a negative value disables tracking.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() value at which the stats were last logged.
last_logging_time: float = 0
# time.perf_counter() value at the start of the current forward pass.
forward_start_time: float = 0
# batch size -> list of forward-pass durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
26
27


28
29
@dataclass
class DPMetadata:
    """Token-count bookkeeping shared across data-parallel (DP) ranks.

    Built by `DPMetadata.make`, which `set_forward_context` calls once per
    forward pass when data_parallel_size > 1.
    """

    # Maximum per-rank token count over all DP ranks (0-dim CPU tensor).
    max_tokens_across_dp_cpu: torch.Tensor
    # Running (inclusive) sum of the per-rank token counts, one entry per
    # DP rank (CPU tensor).
    cu_tokens_across_dp_cpu: torch.Tensor

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """Share every DP rank's token count with every other rank.

        Each rank writes its own `num_tokens` into slot `dp_rank` of a
        zero-initialized int32 CPU tensor of length `dp_size`. Summing these
        tensors with `all_reduce` over the DP group's CPU process group then
        fills in every slot. This is a collective: every DP rank must call it.

        Returns:
            int32 CPU tensor of length `dp_size` holding each rank's count.
        """
        counts = torch.zeros(dp_size, dtype=torch.int32, device="cpu")
        counts[dp_rank] = num_tokens
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(counts, group=get_dp_group().cpu_group)
        return counts

    @staticmethod
    def make(parallel_config: ParallelConfig, attn_metadata: Any,
             num_tokens: int) -> "DPMetadata":
        """Gather token counts from all DP ranks and build a DPMetadata.

        The local token count comes from `attn_metadata` (prefill + decode
        tokens) when it exposes `num_prefill_tokens`, i.e. v0 attention
        backends. Otherwise `num_tokens` is used (v1 backends, or no
        metadata). Asserts data_parallel_size > 1.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        has_v0_counts = (attn_metadata is not None
                         and hasattr(attn_metadata, "num_prefill_tokens"))
        if has_v0_counts:
            local_tokens = (attn_metadata.num_prefill_tokens +
                            attn_metadata.num_decode_tokens)
        else:
            local_tokens = num_tokens

        per_rank = DPMetadata.num_tokens_across_dp(local_tokens, dp_size,
                                                   dp_rank)
        return DPMetadata(torch.max(per_rank),
                          torch.cumsum(per_rank, dim=0))

71

72
73
@dataclass
class ForwardContext:
    """State shared with every layer during a single model forward pass.

    Installed and restored by `set_forward_context`; read through
    `get_forward_context`.
    """

    # Copy from vllm_config.compilation_config.static_forward_context.
    no_compile_layers: dict[str, Any]
    # Set dynamically for each forward pass.
    # v0: a single AttentionMetadata.
    # v1: dict mapping the layer_name of each attention layer to its
    #     AttentionMetadata.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # Set dynamically for each forward pass; `set_forward_context` only
    # fills it in when data_parallel_size > 1.
    dp_metadata: Optional[DPMetadata] = None
87
88
89
90
91
92


# Context of the forward pass currently in progress. It starts as None and
# is installed/restored by `set_forward_context`.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Return the ForwardContext installed by `set_forward_context`.

    Raises AssertionError (when assertions are enabled) if no context is
    currently set.
    """
    ctx = _forward_context
    assert ctx is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return ctx


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """Install a ForwardContext for the duration of one model forward pass.

    Inside the `with` block the context (attention metadata, virtual engine,
    DP metadata, ...) is available via `get_forward_context()`. This is also
    where logic common to every forward pass is injected: data-parallel
    token bookkeeping and optional batch-size timing.

    Args:
        attn_metadata: Attention metadata for this pass (may be None).
        vllm_config: Engine config. It provides the static forward context
            and the data-parallel settings.
        virtual_engine: Virtual engine index stored on the context.
        num_tokens: Number of tokens in this pass. It is used for DP metadata
            and batch-size tracking when `attn_metadata` does not expose
            `num_prefill_tokens` (v1 backends).

    The previously active context is always restored on exit, even if the
    forward pass or the timing/logging code raises, so calls may nest.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        # Keep this pass's start time in a local so that a nested
        # set_forward_context cannot overwrite it. The module-level value is
        # still updated for any code that reads it.
        start_time = time.perf_counter()
        forward_start_time = start_time

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1:
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    forward_succeeded = False
    try:
        yield
        forward_succeeded = True
    finally:
        try:
            # Only time passes that completed. A failed pass would record a
            # misleading duration, and syncing the device after a failure
            # could raise a second error on top of the original one.
            if need_to_track_batchsize and forward_succeeded:
                _record_forward_time(attn_metadata, num_tokens, start_time)
        finally:
            # Restore unconditionally so a failure in the stats code cannot
            # leave a stale context installed.
            _forward_context = prev_context


def _record_forward_time(attn_metadata: Any, num_tokens: int,
                         start_time: float) -> None:
    """Record one forward-pass duration and periodically log statistics.

    Appends the elapsed time in ms since `start_time` to
    `batchsize_forward_time[batchsize]`. When more than
    `batchsize_logging_interval` seconds have passed since the last log, it
    logs (batchsize, count, median_time_ms) for every batch size seen more
    than once, most frequent first.
    """
    global last_logging_time
    import statistics

    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = (attn_metadata.num_prefill_tokens +
                     attn_metadata.num_decode_tokens)
    else:
        # for v1 attention backends
        batchsize = num_tokens

    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            # statistics.median works on the Python floats directly. That
            # avoids a float32 round-trip and torch.quantile's input-size
            # limit as the sample lists (never cleared here) keep growing.
            median = round(statistics.median(times), 2)
            forward_stats.append((bs, len(times), median))
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)