# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of cpu offloading
with LMCache in vLLM v1 or v0.

Usage:

    Specify vLLM version

    -v v0 : Use LMCacheConnector
            model = mistralai/Mistral-7B-Instruct-v0.2
            (Includes enable_chunked_prefill = True)

    -v v1 : Use LMCacheConnectorV1 (default)
            model = meta-llama/Meta-Llama-3.1-8B-Instruct
            (Without enable_chunked_prefill)

Note that `lmcache` is needed to run this example.
Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
To learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
23
import argparse
24
import contextlib
25
26
import os
import time
27
from dataclasses import asdict
28
29
30
31
32
33

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
34
from vllm.engine.arg_utils import EngineArgs
35

36
37
38
39
40
41
42
43
44
45
46
47
48
49

def setup_environment_variables():
    """Export the LMCache settings this example relies on.

    Writes four ``LMCACHE_*`` entries into ``os.environ``, overwriting any
    value that is already set for those names.
    """
    lmcache_settings = {
        # Opt in to LMCache's experimental features.
        "LMCACHE_USE_EXPERIMENTAL": "True",
        # LMCache chunk size: 256 tokens per chunk.
        "LMCACHE_CHUNK_SIZE": "256",
        # Turn on the local CPU backend.
        "LMCACHE_LOCAL_CPU": "True",
        # Cap local CPU memory at 5.0 GB.
        "LMCACHE_MAX_LOCAL_CPU_SIZE": "5.0",
    }
    for name, value in lmcache_settings.items():
        os.environ[name] = value


@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str,
                           vllm_version: str):
    """Yield an ``LLM`` configured with an LMCache KV connector.

    On exit from the ``with`` block (normally or via an exception) the
    LMCache engine registered under ``ENGINE_NAME`` is destroyed.

    Args:
        lmcache_connector: Connector name passed to ``KVTransferConfig``
            as ``kv_connector``.
        model: Model identifier passed to ``EngineArgs``.
        vllm_version: ``"v0"`` additionally sets
            ``enable_chunked_prefill=True``; any other value leaves that
            option at its default.

    Yields:
        The constructed ``LLM`` instance.
    """
    ktc = KVTransferConfig(
        kv_connector=lmcache_connector,
        kv_role="kv_both",
    )

    # GPU memory utilization of 0.8 was chosen for the A40 GPU this example
    # targets. Reduce the value if your GPU has less memory.
    # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
    engine_kwargs = {
        "model": model,
        "kv_transfer_config": ktc,
        "max_model_len": 8000,
        "gpu_memory_utilization": 0.8,
    }
    if vllm_version == "v0":
        # Chunked prefill is only requested explicitly on the v0 path.
        engine_kwargs["enable_chunked_prefill"] = True

    llm_args = EngineArgs(**engine_kwargs)
    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        # Clean up lmcache backend
        LMCacheEngineBuilder.destroy(ENGINE_NAME)


def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    """Run one generation request and print its text and elapsed time.

    Args:
        llm: Engine whose ``generate`` method is called.
        prompt: Prompts passed to ``llm.generate``.
        sampling_params: Sampling settings passed to ``llm.generate``.
        req_str: Label for this request (e.g. ``"first"``), used only in
            the timing line.
    """
    # Should be able to see logs like the following:
    # `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
    # This indicates that the KV cache has been stored in LMCache.
    # perf_counter is a monotonic clock, so the measured interval is not
    # affected by system clock adjustments.
    start = time.perf_counter()
    outputs = llm.generate(prompt, sampling_params)
    # Stop the clock before printing so the figure covers generation only.
    elapsed = time.perf_counter() - start
    print("-" * 50)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print(f"Generation took {elapsed:.2f} seconds, "
          f"{req_str} request done.")
    print("-" * 50)


103
104
105
106
107
108
109
110
111
112
def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is ``-v`` / ``--version``, restricted to ``v0`` or
    ``v1`` and defaulting to ``v1``.
    """
    cli = argparse.ArgumentParser()
    version_option = {
        "choices": ["v0", "v1"],
        "default": "v1",
        "help": "Specify vLLM version (default: v1)",
    }
    cli.add_argument("-v", "--version", **version_option)
    namespace = cli.parse_args()
    return namespace


113
def main():
    """Run the CPU-offloading demo: two requests that share a long prefix.

    Chooses the connector and model from the ``--version`` argument, sets
    the LMCache environment variables, builds the engine, and issues the
    two requests one after the other, printing each result.
    """
    args = parse_args()

    # v0 uses the original connector with a Mistral model; every other
    # value (i.e. v1, the default) uses the V1 connector with Llama 3.1.
    lmcache_connector, model = (
        ("LMCacheConnector", "mistralai/Mistral-7B-Instruct-v0.2")
        if args.version == "v0" else
        ("LMCacheConnectorV1", "meta-llama/Meta-Llama-3.1-8B-Instruct"))

    setup_environment_variables()

    with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:

        # This example script runs two requests with a shared prefix.
        # Define the shared prompt and specific prompts
        shared_prompt = "Hello, how are you?" * 1000
        first_prompt = [
            shared_prompt + "Hello, my name is",
        ]
        second_prompt = [
            shared_prompt + "Tell me a very long story",
        ]

        sampling_params = SamplingParams(temperature=0,
                                         top_p=0.95,
                                         max_tokens=10)

        # Print the first output
        print_output(llm, first_prompt, sampling_params, "first")

        # Short pause between the two requests.
        time.sleep(1)

        # Print the second output
        print_output(llm, second_prompt, sampling_params, "second")


# Run the example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()