# forward_context.py
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional

from vllm.config import VllmConfig


@dataclass
class ForwardContext:
    """Bundle of state made available to model code during a forward pass.

    Instances are created by ``set_forward_context`` and read back through
    ``get_forward_context``.
    """

    # String-keyed mapping copied by reference from
    # ``vllm_config.compilation_config.static_forward_context``.
    static_forward_context: Dict[str, Any]
    # Per-call payload handed to ``set_forward_context`` (its docstring
    # mentions attention metadata as one example).
    # TODO: extend to support per-layer dynamic forward context
    dynamic_forward_context: Any


# Currently active context; None whenever no ``set_forward_context`` block
# is in effect.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` installed by the innermost active
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``; the exception type is kept for existing callers.
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.

    On entry, a new ``ForwardContext`` is built from
    ``vllm_config.compilation_config.static_forward_context`` plus the
    given ``context``, and is made the current context. On exit (normal or
    via an exception), the previously active context is restored.

    Args:
        context: Per-call data stored as ``dynamic_forward_context``.
        vllm_config: Config whose ``compilation_config`` supplies
            ``static_forward_context``.

    Yields:
        None; read the installed context with ``get_forward_context``.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(
        static_forward_context=vllm_config.compilation_config.
        static_forward_context,
        dynamic_forward_context=context)
    try:
        yield
    finally:
        # Restore (rather than clear) so nested set_forward_context blocks
        # unwind back to the enclosing context.
        _forward_context = prev_context