forward_context.py 6.32 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import defaultdict
5
from contextlib import contextmanager
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Optional
8

9
import torch
10
import torch.distributed as dist
11

12
import vllm.envs as envs
13
from vllm.config import VllmConfig
14
15
16
17
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
18
19
from vllm.logger import init_logger

20
21
22
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

23
24
25
26
logger = init_logger(__name__)

# Batch-size timing instrumentation. It is enabled whenever
# VLLM_LOG_BATCHSIZE_INTERVAL is non-negative; the value is the minimum
# number of seconds between two stats log lines.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
track_batchsize: bool = batchsize_logging_interval >= 0
# time.perf_counter() timestamps: last stats emission / current forward start.
last_logging_time: float = 0
forward_start_time: float = 0
# Maps batch size -> list of measured forward durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class DPMetadata:
    """Data-parallel bookkeeping attached to a single forward pass.

    Built by ``set_forward_context`` when data parallelism is enabled.
    """
    # CPU tensor with the cumulative sum (torch.cumsum) of the per-rank
    # token counts gathered across the data-parallel group.
    cu_tokens_across_dp_cpu: torch.Tensor


37
38
@dataclass
class ForwardContext:
    """State shared with model layers for the duration of one forward pass.

    Instances are installed by ``set_forward_context`` and retrieved with
    ``get_forward_context``.
    """
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # set dynamically for each forward pass
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"
    # set dynamically for each forward pass
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int
    # set dynamically for each forward pass; stays None unless
    # data_parallel_size > 1
    dp_metadata: Optional[DPMetadata] = None


# Context installed by the innermost active ``set_forward_context``;
# None when no forward pass is in progress.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    Returns:
        The ``ForwardContext`` installed by the enclosing
        ``set_forward_context`` block.

    Raises:
        AssertionError: if no forward context is set (this check is skipped
            when Python runs with ``-O``).
    """
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def _get_batchsize(attn_metadata: Any, num_tokens: int) -> int:
    """Return the number of tokens processed in this forward pass.

    v0 attention metadata carries ``num_prefill_tokens`` and
    ``num_decode_tokens``. For v1 backends, or when there is no attention
    metadata, the caller-supplied ``num_tokens`` is used instead.
    """
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends or no attn_metadata
    return num_tokens


def _build_dp_metadata(vllm_config: VllmConfig, attn_metadata: Any,
                       num_tokens: int) -> Optional[DPMetadata]:
    """Gather per-rank token counts across the data-parallel group.

    Returns None when data_parallel_size <= 1. Otherwise it issues a
    collective ``dist.all_reduce``, so it must be reached on every DP rank.
    """
    parallel_config = vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    if dp_size <= 1:
        return None
    dp_rank = parallel_config.data_parallel_rank
    # Each rank fills only its own slot; the (default SUM) all-reduce then
    # produces the full per-rank vector on every rank.
    num_tokens_across_dp = [0] * dp_size
    num_tokens_across_dp[dp_rank] = _get_batchsize(attn_metadata, num_tokens)
    num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                     device="cpu",
                                     dtype=torch.int32)
    # Imported lazily, presumably to avoid an import cycle at module load.
    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
    return DPMetadata(cu_tokens_across_dp_cpu)


def _record_forward_time(batchsize: int) -> None:
    """Record this forward's duration and periodically log median times.

    Stats are logged at most once per ``batchsize_logging_interval``
    seconds. Batch sizes seen only once are skipped when logging.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    # NOTE(review): these lists are never trimmed, so they keep growing for
    # the lifetime of the process while tracking is enabled.
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently observed batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: attention metadata for this pass (may be None).
        vllm_config: engine config; supplies the static forward context and
            the data-parallel settings.
        virtual_engine: virtual engine index recorded in the context.
        num_tokens: token count to use when ``attn_metadata`` has no v0
            ``num_prefill_tokens`` field (v1 backends or no metadata).

    The previous context is always restored on exit, even when the KV
    connector or the batch-size statistics code raises.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    dp_metadata = _build_dp_metadata(vllm_config, attn_metadata, num_tokens)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    # Outer try: guarantees the previous context is restored no matter
    # which step below fails.
    try:
        # KVConnector: trigger (possibly async) load before forward.
        # Each attn layer will block until the reading is complete.
        trigger_kv_transfer = (attn_metadata is not None
                               and has_kv_transfer_group()
                               and is_v1_kv_transfer_group())
        if trigger_kv_transfer:
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase_V1)
            kv_connector.start_load_kv(_forward_context)

        try:
            yield
        finally:
            if need_to_track_batchsize:
                # Re-read the metadata after the forward pass, as before.
                _record_forward_time(
                    _get_batchsize(attn_metadata, num_tokens))

            # KVConnector: each attn layer triggers (possibly async) save.
            # Ensure all those operations complete before forward() is done.
            if trigger_kv_transfer:
                kv_connector = get_kv_transfer_group()
                assert isinstance(kv_connector, KVConnectorBase_V1)
                kv_connector.wait_for_save()
    finally:
        _forward_context = prev_context