forward_context.py 5.71 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Optional, Union
8

9
import torch
10
import torch.distributed as dist
11

12
import vllm.envs as envs
13
from vllm.config import VllmConfig
14
15
from vllm.logger import init_logger

16
17
18
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

19
20
21
22
logger = init_logger(__name__)

# Forward-pass timing statistics, keyed by batch size. Collection is enabled
# only when VLLM_LOG_BATCHSIZE_INTERVAL is >= 0. The same value is the number
# of seconds (compared against time.perf_counter() deltas in
# set_forward_context) that must elapse between two stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# perf_counter() timestamp of the most recent stats log line.
last_logging_time: float = 0
# perf_counter() timestamp at which the current tracked forward pass began.
forward_start_time: float = 0
# batch size -> list of forward durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
26
27


28
29
@dataclass
class DPMetadata:
    """Token-count metadata shared across data-parallel (DP) ranks.

    ``set_forward_context`` builds it when ``data_parallel_size > 1``, from
    a CPU int32 tensor holding every DP rank's token count for the current
    forward pass.
    """

    # Largest per-rank token count (``torch.max`` over the per-rank counts).
    max_tokens_across_dp_cpu: torch.Tensor
    # Running total of per-rank token counts (``torch.cumsum`` along dim 0).
    cu_tokens_across_dp_cpu: torch.Tensor


34
35
@dataclass
class ForwardContext:
    """State for one model forward pass.

    Installed by ``set_forward_context`` and read back through
    ``get_forward_context``.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # Attention metadata, set dynamically for each forward pass:
    #   - v0: a single AttentionMetadata object;
    #   - v1: a dict mapping the layer_name of each attention layer to its
    #     attention metadata.
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # Set dynamically for each forward pass; set_forward_context leaves it
    # None unless data_parallel_size > 1.
    dp_metadata: Optional[DPMetadata] = None
49
50
51
52
53
54


# Context installed by the innermost active ``set_forward_context``
# (each call saves the previous value and restores it on exit); None until
# one is installed.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ForwardContext installed by the enclosing
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is set. The check is an
            explicit ``raise`` rather than an ``assert`` statement so it is
            not stripped under ``python -O``, which would otherwise make
            this function silently return None.
    """
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _get_batchsize(attn_metadata: Any, num_tokens: int) -> int:
    """Return the number of tokens processed by the current forward pass.

    v0 attention metadata carries ``num_prefill_tokens`` and
    ``num_decode_tokens``. Otherwise (v1 attention backends, or no attention
    metadata at all) the caller-provided ``num_tokens`` is used.
    """
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends or no attn_metadata
    return num_tokens


def _compute_dp_metadata(vllm_config: VllmConfig, attn_metadata: Any,
                         num_tokens: int) -> Optional[DPMetadata]:
    """Exchange per-rank token counts across the data-parallel group.

    Returns None when ``data_parallel_size <= 1``. Otherwise this performs
    a ``dist.all_reduce`` collective on the DP group's CPU process group.
    """
    parallel_config = vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    if dp_size <= 1:
        return None
    dp_rank = parallel_config.data_parallel_rank
    batchsize = _get_batchsize(attn_metadata, num_tokens)
    # Each rank fills in only its own slot. Summing over the group (the
    # default all_reduce op) then yields every rank's count on every rank.
    num_tokens_across_dp = [0] * dp_size
    num_tokens_across_dp[dp_rank] = batchsize
    num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                     device="cpu",
                                     dtype=torch.int32)
    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
    max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
    return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)


def _record_forward_time(batchsize: int) -> None:
    """Record the duration of the tracked forward pass that just ended.

    When more than ``batchsize_logging_interval`` seconds have passed since
    the last stats line, also log the median forward time of every batch
    size that has been seen more than once.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()

    now = time.perf_counter()
    # time measurement is in milliseconds
    # NOTE(review): these per-batch-size lists are never trimmed, so memory
    # use and the cost of the median computation grow with uptime.
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """Install a ForwardContext for the duration of a model forward pass.

    The context holds the attention metadata, the virtual engine, the
    layers from ``vllm_config.compilation_config.static_forward_context``
    and, when data parallelism is enabled, the per-rank token counts. Code
    inside the ``with`` block reads it via ``get_forward_context()``. This
    is also the place to inject common logic for every model forward pass.

    Args:
        attn_metadata: attention metadata for this pass; may be None.
        vllm_config: engine config supplying the parallel and compilation
            settings.
        virtual_engine: virtual engine index for this pass.
        num_tokens: token count used for DP bookkeeping and batch-size
            stats when ``attn_metadata`` does not carry v0 token counts.

    The previously active context is restored on exit, even if the body or
    the batch-size stats collection raises.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata = _compute_dp_metadata(vllm_config, attn_metadata,
                                       num_tokens)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(
                    _get_batchsize(attn_metadata, num_tokens))
        finally:
            # Restore unconditionally, so a failure in synchronize() or in
            # stats logging cannot leave this pass's context installed.
            _forward_context = prev_context