forward_context.py 3.85 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Dict, Optional
8

9
10
import torch

11
import vllm.envs as envs
12
from vllm.config import VllmConfig
13
14
from vllm.logger import init_logger

15
16
17
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

# Batch-size / forward-latency tracking is enabled when
# VLLM_LOG_BATCHSIZE_INTERVAL is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() timestamp of the last emitted stats log line.
last_logging_time: float = 0.0
# time.perf_counter() timestamp taken at the start of a tracked forward pass.
forward_start_time: float = 0.0
# Minimum time (same units as time.perf_counter(), i.e. seconds) between two
# stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batch size (number of tokens) -> list of forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class ForwardContext:
    """State for the current model forward pass.

    An instance is installed by `set_forward_context` and retrieved with
    `get_forward_context`.
    """
    # copy from vllm_config.compilation_config.static_forward_context
    attn_layers: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass


# The context installed by the innermost active `set_forward_context` block,
# or None outside of any such block.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Return the current forward context.

    Raises:
        AssertionError: if called outside a `set_forward_context` block.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _get_batchsize(attn_metadata: Any) -> int:
    """Return the number of tokens processed by this forward pass."""
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends
    return attn_metadata.num_input_tokens


def _record_forward_time(batchsize: int, start_time: float) -> None:
    """Record the elapsed forward time for `batchsize` (in ms) and, at most
    once per `batchsize_logging_interval`, log per-batchsize median times."""
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch.
    # Guarded so tracking does not crash on platforms without CUDA.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # most frequently seen batch sizes first
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    The new context is visible through `get_forward_context()` inside the
    `with` block; the previous context is restored on exit (also when the
    block or the statistics collection raises), so calls may be nested.

    Args:
        attn_metadata: Attention metadata for this pass. When None, no
            batch-size statistics are recorded.
        vllm_config: Config whose
            `compilation_config.static_forward_context` provides the
            attention layers.
        virtual_engine: Virtual engine index for this pass.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        # Still published for backward compatibility; timing below uses the
        # local so a nested set_forward_context cannot clobber it.
        forward_start_time = start_time

    prev_context = _forward_context
    _forward_context = ForwardContext(
        attn_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata)
    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(_get_batchsize(attn_metadata),
                                     start_time)
        finally:
            # Restore unconditionally so a failure while collecting stats
            # never leaks this pass's context to later callers.
            _forward_context = prev_context