# forward_context.py
1
import time
2
from collections import defaultdict
3
from contextlib import contextmanager
4
5
from dataclasses import dataclass
from typing import Any, Dict, Optional
6

7
8
import torch

9
import vllm.envs as envs
10
from vllm.config import VllmConfig
11
12
13
14
15
16
from vllm.logger import init_logger

logger = init_logger(__name__)

# Batchsize tracking is enabled when the configured interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# `time.perf_counter()` reading taken the last time the stats were logged.
last_logging_time: float = 0
# `time.perf_counter()` reading taken when the latest tracked forward began.
forward_start_time: float = 0
# Minimum gap between two stats log lines; compared against differences of
# `time.perf_counter()` readings, i.e. seconds.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batchsize -> list of observed forward durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
20
21


22
23
24
25
26
27
28
29
30
31
32
@dataclass
class ForwardContext:
    """Bundle of the static and dynamic state visible during a forward pass.

    Instances are created by `set_forward_context` and retrieved with
    `get_forward_context`.
    """
    # Filled from `vllm_config.compilation_config.static_forward_context`.
    static_forward_context: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    # The `context` object handed to `set_forward_context` (e.g. attention
    # metadata); may be None.
    dynamic_forward_context: Any


# Module-global slot for the active forward context. It is None until
# `set_forward_context` installs one, and is reset to the previous value when
# that context manager exits.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The `ForwardContext` currently stored in the module-global slot, as
        installed by `set_forward_context`.

    Raises:
        AssertionError: If no forward context is currently set.
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _record_batchsize_forward_time(context: Any, start_time: float) -> None:
    """Record one forward duration and periodically log per-batchsize stats.

    Args:
        context: The dynamic forward context; must expose either
            `num_prefill_tokens` and `num_decode_tokens`, or
            `num_input_tokens`.
        start_time: `time.perf_counter()` reading taken when the forward
            pass started.
    """
    global last_logging_time

    if hasattr(context, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = context.num_prefill_tokens + context.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = context.num_input_tokens
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)

    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now

    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # Most frequently seen batchsizes first.
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        context: Dynamic per-forward state (e.g. attention metadata). When it
            is None, batchsize tracking is skipped.
        vllm_config: Config whose `compilation_config.static_forward_context`
            becomes the static part of the installed `ForwardContext`.

    Yields:
        None. While the body runs, `get_forward_context()` returns the newly
        installed context; on exit the previously active one is restored.
    """
    global _forward_context, forward_start_time

    need_to_track_batchsize = track_batchsize and context is not None
    # Keep a local copy of the start time: a nested `set_forward_context`
    # overwrites the module-level `forward_start_time`, which is still
    # updated here for anything that reads it.
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        forward_start_time = start_time

    prev_context = _forward_context
    _forward_context = ForwardContext(
        static_forward_context=vllm_config.compilation_config.
        static_forward_context,
        dynamic_forward_context=context)
    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_batchsize_forward_time(context, start_time)
        finally:
            # Always restore, even if the stats bookkeeping above raised,
            # so a failure there cannot leave a stale context installed.
            _forward_context = prev_context