mlpspeculator.py 2.07 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
"""
This file demonstrates the usage of text generation with an LLM model,
comparing the performance with and without speculative decoding.

Note that `v1` is not supported yet, so run this example with:
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""
9

10
11
12
13
14
15
import gc
import time

from vllm import LLM, SamplingParams


16
17
18
def time_generation(
    llm: LLM,
    prompts: list[str],
    sampling_params: SamplingParams,
    title: str,
    *,
    warmup_runs: int = 2,
) -> None:
    """Time one ``llm.generate`` call and print per-token latency and the texts.

    The model is first called ``warmup_runs`` times with the same arguments and
    those results are discarded; only the following call is timed.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``. Each
            returned item is read via ``.outputs[0].token_ids`` and
            ``.outputs[0].text``.
        prompts: Prompts passed unchanged to every ``generate`` call.
        sampling_params: Sampling configuration passed unchanged to every
            ``generate`` call.
        title: Heading printed above the timing line.
        warmup_runs: Number of untimed ``generate`` calls made before the
            measured one. Defaults to 2.
    """
    # Warmup first so the measured call is not the first one the engine serves.
    for _ in range(warmup_runs):
        llm.generate(prompts, sampling_params)

    # perf_counter is monotonic and high-resolution, unlike wall-clock
    # time.time(), which can jump if the system clock is adjusted.
    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.perf_counter() - start

    # Only the first completion of each request is counted and printed.
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    # Avoid ZeroDivisionError when nothing was generated; report NaN instead.
    seconds_per_token = elapsed / total_tokens if total_tokens else float("nan")

    print("-" * 50)
    print(title)
    print("time: ", seconds_per_token)

    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
        print("-" * 50)
35
36


37
def main():
    """Time the same prompt on one model without, then with, speculation.

    Builds an ``LLM`` for the target model, times generation, discards that
    engine, then builds a second ``LLM`` for the same target model with a
    ``speculative_config`` naming a draft model and times generation again.
    """
    # Single source of truth so both runs are guaranteed to use the same model.
    target_model = "meta-llama/Llama-2-13b-chat-hf"

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n"
    )

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model=target_model)

    time_generation(llm, prompts, sampling_params, "Without speculation")

    # Drop the only reference and force a collection before building the
    # second engine (presumably so the first engine's memory can be released
    # before the model is loaded again).
    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model=target_model,
        speculative_config={
            "model": "ibm-ai-platform/llama-13b-accelerator",
        },
    )

    time_generation(llm, prompts, sampling_params, "With speculation")
69
70
71
72


# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()