"examples/hubert/dataset/hubert_dataset.py" did not exist on "7d092896a59e8f91dcf59e9ee4ccd0fbe93166da"
mlpspeculator.py 2.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example demonstrates text generation with an LLM, comparing
performance with and without speculative decoding.

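Speculative decoding uses a small draft model (here an MLPSpeculator) to
propose several tokens per step; the target model verifies them in a single
forward pass, so the output is unchanged while per-token latency can drop.
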
Note that the `v1` engine is not yet supported:
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""

import gc
import time

from vllm import LLM, SamplingParams


def time_generation(
    llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warm up first so one-time startup costs don't skew the timing.
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print("-" * 50)
    print(title)
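    # Average seconds per generated token across the batch.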
    print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
        print("-" * 50)


def main():
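    # Alpaca-style instruction template used to wrap the raw prompts.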
    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n"
    )

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Greedy sampling (temperature=0.0), so speculative decoding
    # should not change the generated text.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    time_generation(llm, prompts, sampling_params, "Without speculation")

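    # Free the first engine before loading the second, to avoid holding
    # two copies of the model weights in GPU memory at once.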
    del llm
    gc.collect()

    # Create an LLM with spec decoding
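    # The draft model is an MLPSpeculator head trained to propose tokens
    # for this base model; vLLM runs it alongside the target model.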
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_config={
            "model": "ibm-ai-platform/llama-13b-accelerator",
        },
    )

    time_generation(llm, prompts, sampling_params, "With speculation")


if __name__ == "__main__":
    main()