# "vllm/vscode:/vscode.git/clone" did not exist on "63e7176f265be43dcc425f5ab4ab45c90234f5c3"
# test_cpu_offloading.py 4.8 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import socket
4
5
import time

6
7
import msgspec
import msgspec.msgpack
8
import pytest
9
10
import zmq
from tqdm import tqdm
11

12
13
14
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
15
16
17
18

# Values for the OffloadingConnector's ``block_size`` extra config; each one
# produces a separate parametrized run of test_cpu_offloading.
CPU_BLOCK_SIZES = [16, 48]


class MockSubscriber:
    """Test helper that listens on a ZMQ SUB socket for KV-cache events.

    Every message is decoded as a ``KVEventBatch``. Callers can then pull out
    the ``BlockStored`` events whose medium is ``"CPU"``.
    """

    def __init__(
        self,
        endpoint: str,
        topic: str,
    ):
        # Process-wide shared ZMQ context; it is not terminated in close().
        self.ctx = zmq.Context.instance()
        self.topic_bytes = topic.encode("utf-8")

        # SUB socket filtered on our topic prefix, connected to the publisher.
        self.sub = self.ctx.socket(zmq.SUB)
        sub_socket = self.sub
        sub_socket.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
        sub_socket.connect(endpoint)

        self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch)

    def get_new_cpu_stored_events(self) -> list[BlockStored]:
        """Drain pending event batches and return the CPU ``BlockStored`` events.

        The first poll waits up to 1000 ms. Once at least one CPU store event
        has been collected, later polls wait only 100 ms. The method returns
        as soon as a poll times out without a readable message.
        """
        collected: list[BlockStored] = []

        poller = zmq.Poller()
        poller.register(self.sub, zmq.POLLIN)

        wait_ms = 1000  # 1 second until the first CPU event shows up
        while dict(poller.poll(wait_ms)).get(self.sub) == zmq.POLLIN:
            # Three frames: topic, an unused middle frame, msgpack payload.
            topic_frame, _, payload = self.sub.recv_multipart()
            assert topic_frame == self.topic_bytes

            batch = self.decoder.decode(payload)
            assert isinstance(batch, KVEventBatch)

            cpu_events = [
                ev
                for ev in batch.events
                if isinstance(ev, BlockStored) and ev.medium == "CPU"
            ]
            if cpu_events:
                collected.extend(cpu_events)
                # Shorten the wait: remaining events should follow quickly.
                wait_ms = 100

        return collected

    def close(self):
        """Close the SUB socket (the shared ZMQ context is left alive)."""
        self.sub.close()


def _timed_generate(
    llm: LLM,
    prompts: list[TokensPrompt],
    sampling_params: SamplingParams,
) -> float:
    """Run one ``llm.generate`` call and return its duration in seconds.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``, so wall-clock adjustments cannot skew the measurement.
    """
    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=False)
    return time.perf_counter() - start


@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
    """
    Tests OffloadingConnector with CPUOffloadingSpec.

    Each of ``num_tests`` distinct prompts is generated three times:
    1. Cold.
    2. Again, expecting a GPU prefix-cache hit.
    3. After resetting the prefix cache, expecting a load from CPU.

    Before the third run, the test asserts that at least one ``BlockStored``
    event with ``medium == "CPU"`` was published. It then requires the
    CPU-hit run to beat the cold run in at least 80% of iterations.
    """

    # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
    kv_transfer_config = KVTransferConfig(
        kv_connector="OffloadingConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "num_cpu_blocks": 1000,
            "block_size": cpu_block_size,
        },
    )

    # Obtain a free port by binding to port 0, then release it so the event
    # publisher can use it (a small race window, acceptable in a test).
    port: int
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("0.0.0.0", 0))
        port = s.getsockname()[1]

    events_endpoint = f"tcp://*:{port}"
    kv_events_config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=events_endpoint,
        topic="test",
    )

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        gpu_memory_utilization=0.5,
        kv_events_config=kv_events_config,
        kv_transfer_config=kv_transfer_config,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=1)

    # The configured endpoint uses the "*" wildcard; the subscriber connects
    # to the same port over loopback.
    events_endpoint = events_endpoint.replace("*", "127.0.0.1")
    subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

    try:
        num_times_cpu_better_than_cold = 0
        num_tests = 10
        total_cold_time = 0.0
        total_gpu_hit_time = 0.0
        total_cpu_hit_time = 0.0
        for i in tqdm(range(num_tests), desc="Running tests"):
            # Build a fresh list each iteration. The first token makes every
            # prompt distinct, and no earlier TokensPrompt aliases this list.
            prompt_token_ids = [i] + [0] * 10000
            prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

            # run generation - this should trigger saving KV cache
            cold_time = _timed_generate(llm, prompts, sampling_params)
            total_cold_time += cold_time

            # run generation again - should hit the GPU prefix cache
            gpu_hit_time = _timed_generate(llm, prompts, sampling_params)
            total_gpu_hit_time += gpu_hit_time

            # reset prefix cache to avoid GPU hit.
            llm.reset_prefix_cache()

            assert subscriber.get_new_cpu_stored_events(), (
                f"no CPU BlockStored events received in iteration {i}"
            )

            # run generation again - this should trigger loading from CPU
            cpu_hit_time = _timed_generate(llm, prompts, sampling_params)
            total_cpu_hit_time += cpu_hit_time

            if cpu_hit_time < cold_time:
                num_times_cpu_better_than_cold += 1

        print("Average times:")
        print(f"    Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
        print(f"    GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
        print(f"    CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

        assert num_times_cpu_better_than_cold >= 0.8 * num_tests, (
            "CPU hit was faster than cold in only "
            f"{num_times_cpu_better_than_cold}/{num_tests} runs"
        )
    finally:
        subscriber.close()
        del llm