# SPDX-License-Identifier: Apache-2.0

3
import os
4
import time
5
from collections import defaultdict
6
from contextlib import contextmanager
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Optional
9

10
import torch
11
import torch.distributed as dist
12

13
import vllm.envs as envs
14
from vllm.config import VllmConfig
15
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
19
from vllm.logger import init_logger
20
from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context
21

22
23
24
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

25
26
27
28
logger = init_logger(__name__)

# Forward-time tracking per batch size is enabled when the configured
# logging interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# `time.perf_counter()` reading taken the last time stats were logged.
last_logging_time: float = 0.0
# `time.perf_counter()` reading taken when the tracked forward pass started.
forward_start_time: float = 0.0
# Minimum gap between two stats log lines; compared against differences of
# `time.perf_counter()` readings, i.e. seconds.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps a batch size to the list of forward times (milliseconds) measured
# for it so far.
batchsize_forward_time: defaultdict = defaultdict(list)
32

33
34
35
36
37
@dataclass
class DPMetadata:
    """Data-parallel token bookkeeping attached to a forward context."""

    # Running (cumulative) sum of the per-rank token counts across the
    # data-parallel group; `set_forward_context` builds it from a CPU tensor.
    cu_tokens_across_dp_cpu: torch.Tensor


38
39
@dataclass
class ForwardContext:
    """State published for the duration of a single model forward pass.

    Instances are created by `set_forward_context` and read back through
    `get_forward_context`.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass; stays None unless
    # `set_forward_context` runs with data_parallel_size > 1
    dp_metadata: Optional[DPMetadata] = None
48
49
50
51
52
53


# The forward context that is currently active, or None when none is set.
# `set_forward_context` installs a new value on entry and puts the previous
# value back on exit; `get_forward_context` reads it.
_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context.

    When `envs.VLLM_ENABLE_TBO` is set, the context is obtained from
    `get_tbo_forward_context()`; otherwise the module-level
    `_forward_context` is used.

    Returns:
        The forward context that is currently set.

    Raises:
        AssertionError: if no forward context has been set.
    """
    if envs.VLLM_ENABLE_TBO:
        forward_context = get_tbo_forward_context()
    else:
        forward_context = _forward_context
    assert forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return forward_context


def _num_tokens_from_attn_metadata(attn_metadata: Any) -> int:
    """Return the token count (batch size) described by `attn_metadata`."""
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends
    return attn_metadata.num_input_tokens


def _record_forward_time(batchsize: int) -> None:
    """Record the elapsed forward time for `batchsize` and log stats.

    Stats are logged only when more than `batchsize_logging_interval` has
    passed since the previous log line.
    """
    global last_logging_time
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    torch.cuda.synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time <= batchsize_logging_interval:
        return
    last_logging_time = now
    forward_stats = []
    for bs, times in batchsize_forward_time.items():
        if len(times) <= 1:
            # can be cudagraph / profiling run
            continue
        median = torch.quantile(torch.tensor(times), q=0.5).item()
        forward_stats.append((bs, len(times), round(median, 2)))
    # most frequently seen batch sizes first
    forward_stats.sort(key=lambda x: x[1], reverse=True)
    if forward_stats:
        logger.info(("Batchsize forward time stats "
                     "(batchsize, count, median_time(ms)): %s"),
                    forward_stats)


@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0,
                        num_tokens: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Args:
        attn_metadata: attention metadata for this forward pass, or None.
        vllm_config: supplies the parallel configuration and the static
            forward context copied into the new `ForwardContext`.
        virtual_engine: stored on the new `ForwardContext`.
        num_tokens: token count reported to the other data-parallel ranks
            when `attn_metadata` is None.

    On exit the previously active forward context is always restored, even
    if the batch-size bookkeeping or the KV-connector wait raises.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1:
        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        if attn_metadata is not None:
            batchsize = _num_tokens_from_attn_metadata(attn_metadata)
        else:
            batchsize = num_tokens
        # Each rank fills in only its own slot, so after the all-reduce
        # (default reduction is a sum) the tensor holds every rank's count.
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = batchsize
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device="cpu",
                                         dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
        dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata)

    # KVConnector: trigger (possibly async) load before forward.
    # Each attn layer will block until the reading is complete.
    trigger_kv_transfer = (attn_metadata is not None
                           and has_kv_transfer_group()
                           and is_v1_kv_transfer_group())
    if trigger_kv_transfer:
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase_V1)
        kv_connector.start_load_kv(_forward_context)
    if envs.VLLM_ENABLE_TBO:
        set_tbo_forward_context(_forward_context)

    try:
        yield
    finally:
        try:
            if need_to_track_batchsize:
                _record_forward_time(
                    _num_tokens_from_attn_metadata(attn_metadata))

            # KVConnector: each attn layer triggers (possibly async) save.
            # Ensure all those operations complete before forward() is done.
            if trigger_kv_transfer:
                kv_connector = get_kv_transfer_group()
                assert isinstance(kv_connector, KVConnectorBase_V1)
                kv_connector.wait_for_save()
        finally:
            # Restore unconditionally so a failure above cannot leave this
            # pass's context installed for later callers.
            _forward_context = prev_context
            if envs.VLLM_ENABLE_TBO:
                set_tbo_forward_context(_forward_context)