# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates example usage of CPU offloading
with LMCache in vLLM v1 or v0.

Usage:

    Specify the vLLM version with -v:

    -v v0 : Use LMCacheConnector
            model = mistralai/Mistral-7B-Instruct-v0.2
            (Includes enable_chunked_prefill = True)

    -v v1 : Use LMCacheConnectorV1 (default)
            model = meta-llama/Meta-Llama-3.1-8B-Instruct
            (Without enable_chunked_prefill)

Note that `lmcache` is needed to run this example.
Requirements: Linux, Python 3.10 or higher, CUDA 12.1.
To learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
23

24
import argparse
25
import contextlib
26
27
import os
import time
28
from dataclasses import asdict
29
30
31
32
33
34

from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
35
from vllm.engine.arg_utils import EngineArgs
36

def setup_environment_variables(vllm_version: str):
    """Configure LMCache and the vLLM engine version via environment variables.

    Call this before the LLM is built; the values are presumably read when
    the engine / LMCache backend is created (TODO confirm against LMCache).

    Args:
        vllm_version: ``"v0"`` or ``"v1"``. ``VLLM_USE_V1`` is set explicitly
            for both values so that a stale setting inherited from the shell
            cannot override the version requested on the command line.
    """
    # LMCache-related environment variables
    # Use experimental features in LMCache
    os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
    # LMCache is set to use 256 tokens per chunk
    os.environ["LMCACHE_CHUNK_SIZE"] = "256"
    # Enable local CPU backend in LMCache
    os.environ["LMCACHE_LOCAL_CPU"] = "True"
    # Set local CPU memory limit to 5.0 GB
    os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"

    # Pin the engine version explicitly. The example pairs LMCacheConnector
    # with v0 and LMCacheConnectorV1 with v1, so an inherited VLLM_USE_V1=0
    # must not silently run a "v1" request on the v0 engine.
    os.environ["VLLM_USE_V1"] = "0" if vllm_version == "v0" else "1"


@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
    """Build an ``LLM`` wired to an LMCache KV connector; destroy LMCache on exit.

    Args:
        lmcache_connector: Name of the vLLM KV connector to use
            (``"LMCacheConnector"`` for v0, ``"LMCacheConnectorV1"`` for v1).
        model: Model name or path passed to vLLM.
        vllm_version: ``"v0"`` or ``"v1"``. Only for ``"v0"`` is
            ``enable_chunked_prefill=True`` passed explicitly; v1 keeps the
            ``EngineArgs`` default.

    Yields:
        The constructed ``LLM`` instance. The LMCache engine registered
        under ``ENGINE_NAME`` is destroyed when the context exits.
    """
    ktc = KVTransferConfig(
        kv_connector=lmcache_connector,
        kv_role="kv_both",
    )

    # Set GPU memory utilization to 0.8 for an A40 GPU with 48GB
    # memory. Reduce the value if your GPU has less memory.
    # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
    engine_kwargs = dict(
        model=model,
        kv_transfer_config=ktc,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
    )
    if vllm_version == "v0":
        engine_kwargs["enable_chunked_prefill"] = True  # Only in v0

    llm_args = EngineArgs(**engine_kwargs)
    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        # Clean up lmcache backend
        LMCacheEngineBuilder.destroy(ENGINE_NAME)


def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    """Generate completions for ``prompt`` and print them with the elapsed time.

    Args:
        llm: The vLLM engine used for generation.
        prompt: Batch of prompt strings passed to ``llm.generate``.
        sampling_params: Sampling settings forwarded to ``llm.generate``.
        req_str: Label for this request (e.g. ``"first"``), used only in the
            timing message.
    """
    # Should be able to see logs like the following:
    # `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
    # This indicates that the KV cache has been stored in LMCache.

    # perf_counter is monotonic, so the measured duration is not affected by
    # wall-clock adjustments the way time.time() is.
    start = time.perf_counter()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    for output in outputs:
        # Only the first completion of each request is printed.
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    elapsed = time.perf_counter() - start
    print(f"Generation took {elapsed:.2f} seconds, {req_str} request done.")
    print("-" * 50)


def parse_args():
    """Read this example's command-line options.

    Returns:
        An ``argparse.Namespace`` whose ``version`` attribute is ``"v0"`` or
        ``"v1"`` (``"v1"`` when ``-v/--version`` is omitted).
    """
    cli = argparse.ArgumentParser()
    version_opts = dict(
        choices=["v0", "v1"],
        default="v1",
        help="Specify vLLM version (default: v1)",
    )
    cli.add_argument("-v", "--version", **version_opts)
    parsed = cli.parse_args()
    return parsed


def main():
    """Run two prompts sharing a long prefix through vLLM with LMCache.

    The prompts differ only in their last few words, so the second request is
    expected to benefit from the KV cache LMCache stored for the shared
    prefix; compare the two printed generation times.
    """
    args = parse_args()
    use_v0 = args.version == "v0"

    # "v0" selects the original connector; any other value uses the V1 one.
    lmcache_connector = "LMCacheConnector" if use_v0 else "LMCacheConnectorV1"
    model = (
        "mistralai/Mistral-7B-Instruct-v0.2"
        if use_v0
        else "meta-llama/Meta-Llama-3.1-8B-Instruct"
    )

    setup_environment_variables(args.version)

    with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
        # Both requests share this long prefix and differ only in the tail.
        shared_prompt = "Hello, how are you?" * 1000
        labeled_requests = [
            ("first", [shared_prompt + "Hello, my name is"]),
            ("second", [shared_prompt + "Tell me a very long story"]),
        ]
        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

        for idx, (label, prompts) in enumerate(labeled_requests):
            if idx > 0:
                # Short pause between the two requests.
                time.sleep(1)
            print_output(llm, prompts, sampling_params, label)


if __name__ == "__main__":
    main()