"platforms/cuda/vscode:/vscode.git/clone" did not exist on "3d64a10ddac3571b429e867fe37823998f30755b"
forward_context.py 2.36 KB
Newer Older
1
2
import time
from collections import Counter
3
from contextlib import contextmanager
4
5
from dataclasses import dataclass
from typing import Any, Dict, Optional
6

7
import vllm.envs as envs
8
from vllm.config import VllmConfig
9
10
11
12
13
14
15
16
from vllm.logger import init_logger

logger = init_logger(__name__)

# Batch-size statistics for forward passes. The interval is compared against
# differences of time.monotonic() values, so it is effectively in seconds.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# A negative interval disables batch-size tracking entirely.
track_batchsize: bool = batchsize_logging_interval >= 0
# Maps an observed batch size (token count of one forward pass) to the
# number of forward passes seen with that size.
batchsize_counter: Counter = Counter()
# time.monotonic() value at which the distribution was last logged;
# 0.0 means it has not been logged yet.
last_logging_time: float = 0.0


@dataclass
class ForwardContext:
    """Bundle of state that model code can read during a forward pass.

    An instance is installed by ``set_forward_context`` and retrieved with
    ``get_forward_context``.
    """

    # Mapping taken from
    # ``vllm_config.compilation_config.static_forward_context``.
    static_forward_context: Dict[str, Any]
    # Per-pass payload handed to ``set_forward_context`` (for example
    # attention metadata); may be None.
    # TODO: extend to support per-layer dynamic forward context
    dynamic_forward_context: Any


# Context currently in effect; None when no ``set_forward_context`` block
# is active.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` installed by the innermost active
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    # An explicit raise (instead of ``assert``) keeps this check active under
    # ``python -O``; AssertionError is preserved for existing callers.
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _log_batchsize(context: Any) -> None:
    """Count the token batch size of one forward pass and periodically log
    the accumulated distribution.

    Does nothing unless batch-size tracking is enabled and ``context`` is
    not None.
    """
    # Only ``last_logging_time`` is rebound; the other module globals are
    # read or mutated in place, which needs no ``global`` declaration.
    global last_logging_time
    if not track_batchsize or context is None:
        return
    if hasattr(context, "num_prefill_tokens"):
        # for v0 attention backends: prefill and decode tokens are separate
        batchsize = context.num_prefill_tokens + context.num_decode_tokens
    else:
        # for v1 attention backends: a single ``num_input_tokens`` count
        batchsize = context.num_input_tokens
    batchsize_counter[batchsize] += 1
    now = time.monotonic()
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        # most_common() yields (batchsize, count) pairs by descending count,
        # with ties in first-seen order.
        logger.info("Batchsize distribution (batchsize, count): %s",
                    batchsize_counter.most_common())


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context, which
    can be attention metadata, etc.

    Here we can inject common logic for every model forward pass.
    The previously active context is restored on exit, even if the body
    raises, so nested uses are supported.

    Args:
        context: Per-pass payload stored as ``dynamic_forward_context``.
        vllm_config: Config whose ``compilation_config.static_forward_context``
            is stored as ``static_forward_context``.
    """
    global _forward_context
    _log_batchsize(context)

    compilation_config = vllm_config.compilation_config
    prev_context = _forward_context
    _forward_context = ForwardContext(
        static_forward_context=compilation_config.static_forward_context,
        dynamic_forward_context=context)
    try:
        yield
    finally:
        _forward_context = prev_context