"""Model list and helper for checking that models run under full-graph capture."""
import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
7
from vllm.compilation.levels import CompilationLevel
8
from vllm.platforms import current_platform
9
10
11

# Models exercised by the full-graph capture check.
# Each entry is a ``(model identifier, keyword-argument dict)`` pair; the dict
# has the same shape as the ``model_kwargs`` that check_full_graph_support
# forwards to ``LLM(...)``.
TEST_MODELS = [
    ("facebook/opt-125m", {}),
    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
        "dtype": torch.float16,
        "quantization": "compressed-tensors"
    }),
    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
        "dtype": torch.float16,
        "quantization": "fp8"
    }),
    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
        "quantization": "compressed-tensors"
    }),
    ("meta-llama/Meta-Llama-3-8B", {}),
]

# The remaining models are appended only when is_quant_method_supported()
# reports the quantization method as available (presumably a check of the
# current hardware/build -- confirm in tests.quantization.utils).
if is_quant_method_supported("aqlm"):
    TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
        "quantization": "aqlm"
    }))

# TODO: figure out why this fails.
# The leading ``False and`` keeps this entry disabled on purpose; because of
# short-circuiting, the gguf support check is not even evaluated.
if False and is_quant_method_supported("gguf"):  # noqa: SIM223
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
        "quantization": "gguf"
    }))

if is_quant_method_supported("gptq"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
        "quantization": "gptq"
    }))

if is_quant_method_supported("gptq_marlin"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
        "quantization": "gptq_marlin"
    }))

if is_quant_method_supported("gptq_marlin_24"):
    TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
        "quantization": "gptq_marlin_24"
    }))

if is_quant_method_supported("marlin"):
    TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
        "quantization": "marlin"
    }))

# The AWQ model is never added on ROCm; the platform check runs first, so the
# support check is skipped entirely there.
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
        "quantization": "AWQ"
    }))


63
64
65
66
def check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1):
    """Load ``model`` and generate from a few fixed prompts.

    The function asserts nothing itself: it sets the compile-related
    environment variables, builds an ``LLM`` and runs ``generate``, so a
    failure surfaces as whatever exception those calls raise.

    Args:
        model: Model identifier passed to ``LLM(model=...)``.
        model_kwargs: Extra keyword arguments forwarded to ``LLM(...)``.
        optimization_level: Written (as a string) to
            ``VLLM_TORCH_COMPILE_LEVEL`` and compared against
            ``CompilationLevel.PIECEWISE`` for the skip rule below.
        tp_size: Value passed as ``tensor_parallel_size``. Defaults to 1.

    Returns:
        None. Prompts and generated texts are printed to stdout.
    """
    # make sure these models can be captured in full graph mode
    # NOTE: both variables stay set in os.environ after this call returns,
    # including when the early return below is taken.
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

    # The base meta llama uses too much memory.
    if (model == "meta-llama/Meta-Llama-3-8B"
            and optimization_level >= CompilationLevel.PIECEWISE):
        return

    print(f"MODEL={model}")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model,
              enforce_eager=True,
              tensor_parallel_size=tp_size,
              disable_custom_all_reduce=True,
              **model_kwargs)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        # Only the first completion of each request is shown.
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")