# forward_context.py (3.81 KB)
1
import time
2
from collections import defaultdict
3
from contextlib import contextmanager
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Dict, Optional
6

7
8
import torch

9
import vllm.envs as envs
10
from vllm.config import VllmConfig
11
12
from vllm.logger import init_logger

13
14
15
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

16
17
18
19
logger = init_logger(__name__)

# Batch-size timing statistics are collected only when
# VLLM_LOG_BATCHSIZE_INTERVAL is non-negative. The same value is the minimum
# gap, in seconds (time.perf_counter() units), between two stats log lines.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# time.perf_counter() timestamp of the most recent stats log line.
last_logging_time: float = 0
# time.perf_counter() timestamp at which the latest tracked forward started.
forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Batch size (number of tokens in a forward pass) -> list of forward
# durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class ForwardContext:
    """State shared with model code for a single forward pass.

    Instances are created and installed by `set_forward_context` and
    retrieved with `get_forward_context`.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    attn_layers: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    # Set dynamically for each forward pass; may be None when the caller
    # passes no attention metadata.
    attn_metadata: Optional["AttentionMetadata"]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass


# Context of the forward pass currently in progress; None when no
# `set_forward_context` block is active.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The `ForwardContext` installed by the innermost active
        `set_forward_context` block.

    Raises:
        AssertionError: if no `set_forward_context` block is active. The
            check is an ``assert``, so it is skipped under ``python -O``.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _forward_batchsize(attn_metadata: Any) -> int:
    """Return the number of tokens processed by one forward pass.

    Metadata exposing ``num_prefill_tokens`` (v0 backends) contributes
    prefill + decode tokens; otherwise (v1 backends) ``num_input_tokens``.
    """
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends
    return attn_metadata.num_input_tokens


def _record_forward_time(batchsize: int, start_time: float) -> None:
    """Record one forward duration and periodically log per-batchsize stats.

    Args:
        batchsize: number of tokens in the pass; used as the stats bucket key.
        start_time: ``time.perf_counter()`` value taken when the pass began.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    # NOTE(review): these per-batchsize lists are never trimmed, so memory
    # use and the cost of each median computation grow with process lifetime.
    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batch sizes first.
    forward_stats.sort(key=lambda stat: stat[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0):
    """Install a `ForwardContext` for the duration of a model forward pass.

    This is the place to inject common logic for every model forward pass.
    Currently that logic is optional batch-size timing statistics, enabled
    by the module-level ``track_batchsize`` flag.

    Args:
        attn_metadata: attention metadata for this pass. May be ``None``, in
            which case timing statistics are not collected.
        vllm_config: its ``compilation_config.static_forward_context`` is
            exposed as ``ForwardContext.attn_layers``.
        virtual_engine: stored as ``ForwardContext.virtual_engine``.

    The previously active context (if any) is restored on exit. This also
    happens when the wrapped block or the statistics collection raises.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        # Measure from a local copy: a nested set_forward_context overwrites
        # the module-level value, which would otherwise skew this pass.
        start_time = time.perf_counter()
        forward_start_time = start_time

    prev_context = _forward_context
    _forward_context = ForwardContext(
        attn_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata)
    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(_forward_batchsize(attn_metadata),
                                     start_time)
        finally:
            # Always restore the enclosing context, even if stats collection
            # raised, so this pass's context cannot leak to later callers.
            _forward_context = prev_context