lora_with_quantization_inference.py 4.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import Optional

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
    lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the list of demo prompts, each paired with its request settings.

    Args:
        lora_path: Path to the LoRA adapter, forwarded to every
            ``LoRARequest`` created here.

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples. The
        first entry has ``None`` as its LoRA request; the remaining three
        carry a ``LoRARequest`` with a distinct name, the integer ID 1 and
        ``lora_path``.
    """

    def _default_params() -> SamplingParams:
        # A separate SamplingParams object is created for every prompt so
        # that no two requests share the same instance.
        return SamplingParams(
            temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
        )

    # (prompt text, adapter name); a name of None means "no LoRA request".
    prompt_specs = [
        # this is an example of using quantization without LoRA
        ("My name is", None),
        # the next three examples use quantization with LoRA
        ("my name is", "lora-test-1"),
        ("The capital of USA is", "lora-test-2"),
        ("The capital of France is", "lora-test-3"),
    ]

    return [
        (
            text,
            _default_params(),
            None if adapter_name is None else LoRARequest(adapter_name, 1, lora_path),
        )
        for text, adapter_name in prompt_specs
    ]


def process_requests(
    engine: LLMEngine,
    test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
    """Submit the queued prompts to the engine and print finished results.

    On every iteration at most one pending prompt is handed to the engine,
    then the engine is stepped once. The loop ends when no prompts are left
    and the engine reports no unfinished requests.

    Args:
        engine: Engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples.
            The list is consumed from the front and is empty on return.
    """
    next_id = 0

    while True:
        # Stop only when nothing is left to submit and nothing is in flight.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            text, params, lora = test_prompts.pop(0)
            # Request IDs are consecutive integers rendered as strings.
            engine.add_request(str(next_id), text, params, lora_request=lora)
            next_id += 1

        step_results: list[RequestOutput] = engine.step()
        for result in step_results:
            if not result.finished:
                continue
            print("----------------------------------------------------")
            print(f"Prompt: {result.prompt}")
            print(f"Output: {result.outputs[0].text}")


def initialize_engine(
    model: str, quantization: str, lora_repo: Optional[str]
) -> LLMEngine:
    """Create an LLMEngine for ``model`` with LoRA enabled.

    Args:
        model: Model identifier forwarded to ``EngineArgs``.
        quantization: Quantization method name forwarded to ``EngineArgs``.
        lora_repo: Not read by this function; kept in the signature so
            existing callers continue to work.

    Returns:
        The engine built by ``LLMEngine.from_engine_args``.
    """
    # LoRA-related settings shared by every configuration in this example.
    lora_settings = {
        "enable_lora": True,
        "max_lora_rank": 64,
        "max_loras": 4,
    }
    args = EngineArgs(model=model, quantization=quantization, **lora_settings)
    return LLMEngine.from_engine_args(args)


def main():
    """Run the LoRA inference demo once per quantization configuration.

    For each configuration an engine is created, the LoRA adapter snapshot
    is downloaded, the test prompts are processed, and the engine is
    released before the next configuration starts.
    """

    configs = (
        # QLoRA (https://arxiv.org/abs/2305.14314)
        {
            "name": "qlora_inference_example",
            "model": "huggyllama/llama-7b",
            "quantization": "bitsandbytes",
            "lora_repo": "timdettmers/qlora-flan-7b",
        },
        {
            "name": "AWQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
            "quantization": "awq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
        {
            "name": "GPTQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            "quantization": "gptq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
    )

    for config in configs:
        name = config["name"]
        model = config["model"]
        quantization = config["quantization"]
        lora_repo = config["lora_repo"]

        print(f"~~~~~~~~~~~~~~~~ Running: {name} ~~~~~~~~~~~~~~~~")

        # The engine is built first, then the adapter snapshot is fetched.
        engine = initialize_engine(model, quantization, lora_repo)
        adapter_dir = snapshot_download(repo_id=lora_repo)
        process_requests(engine, create_test_prompts(adapter_dir))

        # Drop the engine and release cached GPU memory before the next
        # configuration is loaded.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()