# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import socket
4
5
import time

6
7
import msgspec
import msgspec.msgpack
8
import pytest
9
10
import zmq
from tqdm import tqdm
11

12
13
14
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
15
from vllm.platforms import current_platform
16

# CPU offload block sizes (in tokens) to parametrize the test over.
CPU_BLOCK_SIZES = [48]

# Attention backends to exercise on the current platform. On any other
# platform the list stays empty and pytest's default for an empty parameter
# set is to skip the parametrized test.
ATTN_BACKENDS: list[str] = []
if current_platform.is_cuda():
    ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER", "TRITON_ATTN"]
elif current_platform.is_rocm():
    ATTN_BACKENDS = ["TRITON_ATTN"]

# Maximum time (seconds) to wait for the async CPU offload transfer
# to complete before giving up.
_RESET_CACHE_TIMEOUT = 30 if current_platform.is_rocm() else 10

# ZMQ poll timeout (ms) for the first event.
_FIRST_EVENT_POLL_MS = 10_000 if current_platform.is_rocm() else 1000

# Hard ceiling (seconds) on how long get_new_cpu_stored_events may loop,
# to prevent hangs if non-CPU events keep arriving indefinitely.
_EVENT_DRAIN_TIMEOUT = 60

class MockSubscriber:
    """Helper class to receive and verify published KV-cache events.

    Connects a ZMQ SUB socket to the engine's KV-event publisher, filtered
    to a single topic, and decodes each payload as a ``KVEventBatch``.
    """

    def __init__(
        self,
        endpoint: str,
        topic: str,
    ):
        # Context.instance() is the process-wide shared context, so close()
        # only closes our socket and never terminates the context.
        self.ctx = zmq.Context.instance()
        self.topic_bytes = topic.encode("utf-8")

        # Set up subscriber socket, receiving only messages for our topic.
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
        self.sub.connect(endpoint)

        self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch)

    def get_new_cpu_stored_events(self) -> list[BlockStored]:
        """Drain pending event batches and return the CPU ``BlockStored`` events.

        Waits up to ``_FIRST_EVENT_POLL_MS`` for a message. Once a CPU store
        event has been seen, later polls use a short 100 ms timeout so the
        call returns soon after the publisher goes quiet. The whole drain is
        capped at ``_EVENT_DRAIN_TIMEOUT`` seconds, and no single poll is
        allowed to block past that deadline.

        Returns:
            The CPU-medium ``BlockStored`` events received, in arrival order
            (possibly empty).
        """
        cpu_stored_events: list[BlockStored] = []

        poller = zmq.Poller()
        poller.register(self.sub, zmq.POLLIN)

        poll_ms = _FIRST_EVENT_POLL_MS
        deadline = time.monotonic() + _EVENT_DRAIN_TIMEOUT
        while True:
            remaining_ms = int((deadline - time.monotonic()) * 1000)
            if remaining_ms <= 0:
                break

            # Clamp so one long first-event poll cannot overshoot the deadline.
            events = dict(poller.poll(min(poll_ms, remaining_ms)))
            if not (events.get(self.sub, 0) & zmq.POLLIN):
                # Poll timed out with nothing readable: publisher is quiet.
                break

            # Middle frame is unused here (presumably a sequence number).
            topic_bytes, _, payload = self.sub.recv_multipart()

            assert topic_bytes == self.topic_bytes

            event_batch = self.decoder.decode(payload)
            assert isinstance(event_batch, KVEventBatch)
            for event in event_batch.events:
                if isinstance(event, BlockStored) and event.medium == "CPU":
                    cpu_stored_events.append(event)
                    # Offload has started; only linger briefly for the rest.
                    poll_ms = 100

        return cpu_stored_events

    def close(self):
        """Clean up resources (the subscriber socket)."""
        self.sub.close()


def _wait_for_prefix_cache_reset(llm: LLM) -> None:
    """Block until ``llm.reset_prefix_cache()`` succeeds.

    The GPU-to-CPU offload completes asynchronously. While blocks are still
    held by the offload worker, ``reset_prefix_cache`` returns ``False``, so
    keep retrying every 100 ms.

    Raises:
        TimeoutError: if no reset succeeds within ``_RESET_CACHE_TIMEOUT``
            seconds.
    """
    give_up_at = time.monotonic() + _RESET_CACHE_TIMEOUT
    while True:
        if llm.reset_prefix_cache():
            return
        if time.monotonic() > give_up_at:
            raise TimeoutError(
                "reset_prefix_cache did not succeed within "
                f"{_RESET_CACHE_TIMEOUT}s - async offload may be stuck"
            )
        time.sleep(0.1)


def _timed_generate(
    llm: LLM, prompts: list[TokensPrompt], sampling_params: SamplingParams
) -> float:
    """Run ``llm.generate`` once and return its elapsed time in seconds."""
    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=False)
    return time.perf_counter() - start


def _latency_test(llm: LLM, subscriber: MockSubscriber):
    """Check that CPU-offloaded prefix hits are faster than cold prefills.

    Each iteration runs a distinct 10001-token prompt three times:
      1. cold (triggers saving the KV cache),
      2. again, expected to hit the GPU prefix cache,
      3. after waiting for the CPU offload and resetting the GPU prefix
         cache, expected to reload the KV cache from CPU.
    At least 80% of iterations must show run 3 beating run 1.
    """
    sampling_params = SamplingParams(max_tokens=1)

    num_times_cpu_better_than_cold = 0
    num_tests = 10
    total_cold_time = 0.0
    total_gpu_hit_time = 0.0
    total_cpu_hit_time = 0.0
    for i in tqdm(range(num_tests), desc="Running tests"):
        # A distinct first token gives every iteration a fresh prefix. Build
        # a new list each time instead of mutating one shared list that
        # earlier TokensPrompt objects still reference.
        prompts = [TokensPrompt(prompt_token_ids=[i] + [0] * 10000)]

        # run generation - this should trigger saving KV cache
        cold_time = _timed_generate(llm, prompts, sampling_params)
        total_cold_time += cold_time

        # run generation again - should hit the GPU prefix cache
        total_gpu_hit_time += _timed_generate(llm, prompts, sampling_params)

        # Wait for the async CPU offload to finish, then reset prefix cache
        # so the next generate() must reload from CPU rather than GPU.
        _wait_for_prefix_cache_reset(llm)

        # Verify CPU stored events arrived (offload is done before we
        # attempt to load from CPU).
        assert subscriber.get_new_cpu_stored_events(), (
            f"No CPU stored events received on iteration {i}; "
            "async offload may not have completed in time"
        )

        # run generation again - this should trigger loading from CPU
        cpu_hit_time = _timed_generate(llm, prompts, sampling_params)
        total_cpu_hit_time += cpu_hit_time

        if cpu_hit_time < cold_time:
            num_times_cpu_better_than_cold += 1

    print("Average times:")
    print(f"    Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
    print(f"    GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
    print(f"    CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

    assert num_times_cpu_better_than_cold >= 0.8 * num_tests, (
        f"CPU hit beat cold prefill in only {num_times_cpu_better_than_cold}"
        f"/{num_tests} iterations"
    )


def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
    """Check that generations reusing CPU-offloaded KV blocks stay correct.

    Pads a counting prompt until its token count is a multiple of the CPU
    offload block size, checks that CPU ``BlockStored`` events were
    published, then requires at least half of 100 single-token generations
    to continue the count with " five".
    """
    sampling_params = SamplingParams(max_tokens=1)
    cpu_block_size = (
        llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
            "block_size"
        ]
    )

    # Discard any events left over from earlier phases.
    subscriber.get_new_cpu_stored_events()

    # Prepend ". " until the prompt is CPU block aligned. The generate()
    # calls here are also what populate the KV cache checked below. The loop
    # is bounded so a tokenizer that never lands on a multiple of the block
    # size fails the test instead of hanging it.
    prompt = "Let's count to 10. One, two, three, four,"
    for _ in range(2 * cpu_block_size):
        num_prompt_tokens = len(
            llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids
        )
        if num_prompt_tokens % cpu_block_size == 0:
            break
        prompt = ". " + prompt
    else:
        raise AssertionError(
            f"Could not align prompt to a multiple of {cpu_block_size} tokens"
        )

    assert subscriber.get_new_cpu_stored_events(), (
        "No CPU stored events received while aligning the prompt"
    )

    test_count = 100
    success_count = sum(
        llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
        == " five"
        for _ in range(test_count)
    )

    assert success_count >= 0.5 * test_count, (
        f"Expected ' five' in at least half of {test_count} runs, "
        f"got {success_count}"
    )


@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
    """
    Tests OffloadingConnector with CPUOffloadingSpec.

    Builds one LLM with KV-cache events published over ZMQ. It then runs a
    latency check (CPU prefix hits vs. cold prefill) and an accuracy check
    (generation reusing CPU-offloaded blocks), observing offload progress
    through the event stream.
    """

    # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
    kv_transfer_config = KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": 500 << 20,  # 500 MiB
            "block_size": cpu_block_size,
        },
    )

    # Find a free TCP port for the KV-event publisher. The probe socket is
    # closed before the publisher binds, so this is best-effort.
    port: int
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("0.0.0.0", 0))
        port = s.getsockname()[1]

    events_endpoint = f"tcp://*:{port}"
    kv_events_config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=events_endpoint,
        topic="test",
    )

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        gpu_memory_utilization=0.5,
        kv_events_config=kv_events_config,
        kv_transfer_config=kv_transfer_config,
        attention_config={"backend": attn_backend},
        # ROCm: batch size 1 to reduce variability
        **({"max_num_seqs": 1} if current_platform.is_rocm() else {}),
    )

    # Nested try blocks: the LLM is released even if creating the
    # subscriber fails, and the subscriber is closed whenever it was created.
    try:
        # The publisher binds on all interfaces ("*"); connect via loopback.
        subscriber = MockSubscriber(
            events_endpoint.replace("*", "127.0.0.1"), topic=kv_events_config.topic
        )
        try:
            _latency_test(llm, subscriber)
            _accuracy_test(llm, subscriber)
        finally:
            subscriber.close()
    finally:
        del llm