# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vLLM instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).

Note that lmcache needs to be installed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import sys
import time
from multiprocessing import Event, Process

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# LMCache-related configuration.
# The port on which the LMCache server is started.
port = 8100

# LMCache reads its settings from environment variables; set them all at
# once so the whole configuration is visible in a single place.
os.environ.update({
    # Use experimental features in LMCache
    "LMCACHE_USE_EXPERIMENTAL": "True",
    # LMCache is set to use 256 tokens per chunk
    "LMCACHE_CHUNK_SIZE": "256",
    # Disable local CPU backend in LMCache
    "LMCACHE_LOCAL_CPU": "False",
    # Set local CPU memory buffer limit to 5.0 GB
    "LMCACHE_MAX_LOCAL_CPU_SIZE": "5.0",
    # Set the remote URL for LMCache server
    "LMCACHE_REMOTE_URL": f"lm://localhost:{port}",
    # Set the serializer/deserializer between vllm and LMCache server
    # `naive` indicates using raw bytes of the tensor without any compression
    "LMCACHE_REMOTE_SERDE": "naive",
})

# Prompts shared by the store and retrieve processes.
# A single long prompt: one short sentence repeated 1000 times.
prompts = ["Hello, how are you?" * 1000]


def run_store(store_done, prompts):
    """Run generation on GPU 0 so the KV cache is stored through LMCache.

    Args:
        store_done: Event that is set once generation has finished, signalling
            to whoever waits on it that the store step is done.
        prompts: Prompts passed to ``LLM.generate``.
    """
    # We use GPU 0 for KV cache store process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
    # GPU memory utilization of 0.8 was chosen for an A40 GPU.
    # Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    try:
        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            generated_text = output.outputs[0].text
            print(f"Generated text: {generated_text!r}")
        print("KV cache store is finished.")
        # Only signal completion when generation actually succeeded.
        store_done.set()
    finally:
        # Clean up lmcache backend, even if generation raised.
        LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_retrieve(store_done, prompts, timeout=1):
    """Run generation on GPU 1 after the store step has finished.

    The engine is created first, then the function blocks until
    ``store_done`` is set before generating on the same prompts.

    Args:
        store_done: Event that is waited on before generation starts.
        prompts: Prompts passed to ``LLM.generate``.
        timeout: Extra delay in seconds to sleep after ``store_done`` is set
            and before generating. Despite the name, it is a fixed pause,
            not an upper bound on the wait.
    """
    # We use GPU 1 for KV cache retrieve process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
    # GPU memory utilization of 0.8 was chosen for an A40 GPU.
    # Reduce the value if your GPU has less memory.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
              kv_transfer_config=ktc,
              max_model_len=8000,
              gpu_memory_utilization=0.8,
              enforce_eager=True)

    try:
        print("Waiting for KV cache store to finish...")
        store_done.wait()
        time.sleep(timeout)

        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            generated_text = output.outputs[0].text
            print(f"Generated text: {generated_text!r}")
    finally:
        # Clean up lmcache backend, even if generation raised.
        LMCacheEngineBuilder.destroy(ENGINE_NAME)


def run_lmcache_server(port):
    """Launch the LMCache server as a child process on localhost.

    Args:
        port: Port number passed on the server's command line.

    Returns:
        The ``subprocess.Popen`` handle of the server process. The caller is
        responsible for terminating and waiting on it.
    """
    # Launch with the interpreter that is running this script instead of
    # whatever "python" happens to resolve to on PATH, so the server sees the
    # same environment (and the same lmcache installation) as this script.
    interpreter = sys.executable or "python"
    command = [
        interpreter, "-m", "lmcache.experimental.server", "localhost",
        str(port)
    ]
    return subprocess.Popen(command)


def main():
    """Run the example: LMCache server plus one store and one retrieve process.

    The store and retrieve workers run as separate processes and are
    coordinated through a shared event. The LMCache server and any worker
    that is still alive are shut down on exit, including on failure.
    """
    store_done = Event()
    store_process = Process(target=run_store, args=(store_done, prompts))
    retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    try:
        # Start KV cache store process
        store_process.start()

        # Start KV cache retrieve process
        retrieve_process.start()

        store_process.join()
        # Wait for the retrieve process to finish its own generation instead
        # of killing it as soon as the store process exits. If the store
        # process exited without signalling completion, the retrieve process
        # would block forever, so it is left to the cleanup below.
        if store_done.is_set():
            retrieve_process.join()
    finally:
        # Clean up the processes
        for worker in (store_process, retrieve_process):
            # is_alive() is False for a process that was never started, so
            # terminate() is only called on a running child.
            if worker.is_alive():
                worker.terminate()
                worker.join()
        lmcache_server_process.terminate()
        lmcache_server_process.wait()


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()