forward_context.py 5.31 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Optional
8

9
import torch
10
import torch.distributed as dist
11

12
import vllm.envs as envs
13
from vllm.config import VllmConfig
14
15
from vllm.logger import init_logger

16
17
18
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

19
20
21
22
logger = init_logger(__name__)

# Batch-size / forward-time tracking. The interval comes from
# VLLM_LOG_BATCHSIZE_INTERVAL; a negative value disables tracking. The
# interval is compared against `time.perf_counter()` differences, so it is
# expressed in seconds.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# perf_counter() timestamp of the last stats log line.
last_logging_time: float = 0
# perf_counter() timestamp at which the most recent tracked forward began.
forward_start_time: float = 0
# batch size (tokens) -> list of forward durations in milliseconds.
batchsize_forward_time: defaultdict[int, list[float]] = defaultdict(list)
26
27


28
29
30
31
32
@dataclass
class DPMetadata:
    """Token bookkeeping for data-parallel execution.

    As built by ``set_forward_context``, ``cu_tokens_across_dp_cpu`` is a
    CPU tensor holding the cumulative sum of the per-DP-rank token counts.
    """

    cu_tokens_across_dp_cpu: torch.Tensor


33
34
@dataclass
class ForwardContext:
    """State describing the current model forward pass.

    Instances are installed by ``set_forward_context`` and retrieved with
    ``get_forward_context``.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]

    # Set dynamically for each forward pass.
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"

    # Set dynamically for each forward pass.
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int

    # Set dynamically for each forward pass; set_forward_context only fills
    # this in when data_parallel_size > 1.
    dp_metadata: Optional[DPMetadata] = None
43
44
45
46
47
48


# Context installed by the innermost active `set_forward_context` block;
# None when no such block is active.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` installed by the innermost active
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is currently set. This is
            raised explicitly rather than via ``assert``, so the check is
            not stripped when Python runs with ``-O``.
    """
    if _forward_context is None:
        raise AssertionError(
            "Forward context is not set. "
            "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _num_tokens_from_attn_metadata(attn_metadata: Any) -> int:
    """Return the number of tokens described by ``attn_metadata``.

    Metadata exposing ``num_prefill_tokens`` (v0 attention backends) is
    counted as prefill + decode tokens; otherwise (v1 attention backends)
    ``num_input_tokens`` is used.
    """
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends
    return attn_metadata.num_input_tokens


def _make_dp_metadata(vllm_config: VllmConfig, attn_metadata: Any,
                      num_tokens: int) -> Optional[DPMetadata]:
    """Build data-parallel metadata; return None when DP size is <= 1.

    Each DP rank writes its own token count into an otherwise-zero vector
    and the vectors are summed with ``all_reduce`` over the DP CPU group, so
    every rank sees every rank's count. ``all_reduce`` is a collective: all
    DP ranks must reach this call together.
    """
    parallel_config = vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    if dp_size <= 1:
        return None
    dp_rank = parallel_config.data_parallel_rank
    if attn_metadata is not None:
        batchsize = _num_tokens_from_attn_metadata(attn_metadata)
    else:
        batchsize = num_tokens
    num_tokens_across_dp = [0] * dp_size
    num_tokens_across_dp[dp_rank] = batchsize
    num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                     device="cpu",
                                     dtype=torch.int32)
    # Lazy import, presumably to avoid a circular import (TODO confirm).
    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
    return DPMetadata(cu_tokens_across_dp_cpu)


def _record_forward_time(batchsize: int, start_time: float) -> None:
    """Record one forward duration and periodically log median stats.

    Durations are stored in milliseconds per batch size. At most once per
    ``batchsize_logging_interval`` seconds, logs (batchsize, count, median)
    for every batch size seen more than once, most frequent first.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch.
    # Guarded so tracking does not crash on builds without CUDA.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append((now - start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: Attention metadata for this forward pass, or None.
        vllm_config: Supplies ``parallel_config`` (for data-parallel
            metadata) and ``compilation_config.static_forward_context``.
        virtual_engine: Stored on the installed ``ForwardContext``.
        num_tokens: Token count used for data-parallel metadata when
            ``attn_metadata`` is None.

    The previous context is restored on exit, so these blocks may nest.
    """
    global _forward_context, forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    start_time = 0.0
    if need_to_track_batchsize:
        start_time = time.perf_counter()
        # Also published module-wide for backward compatibility. This call
        # measures against its local copy, so a nested tracked context
        # cannot overwrite the outer pass's start time.
        forward_start_time = start_time

    dp_metadata = _make_dp_metadata(vllm_config, attn_metadata, num_tokens)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)
    try:
        yield
    finally:
        # Restore first so a failure while collecting stats below can never
        # leave this (stale) context installed.
        _forward_context = prev_context
        if need_to_track_batchsize:
            _record_forward_time(
                _num_tokens_from_attn_metadata(attn_metadata), start_time)