# SPDX-License-Identifier: Apache-2.0

import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

# Models exercised by the compile tests. Each entry is a
# ``(model name, extra keyword arguments)`` pair.
# NOTE(review): the kwargs dict is presumably forwarded to ``LLM(...)`` via
# ``check_full_graph_support(model, model_kwargs, ...)`` -- confirm in callers.
TEST_MODELS = [
    ("facebook/opt-125m", {}),
    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
        "dtype": torch.float16,
        "quantization": "compressed-tensors"
    }),
    ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
        "dtype": torch.float16,
        "quantization": "compressed-tensors"
    }),
    ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {
        "quantization": "compressed-tensors"
    }),
    ("meta-llama/Llama-3.2-1B-Instruct", {}),
]

# Register the AQLM checkpoint only when ``is_quant_method_supported``
# reports the "aqlm" method as supported.
if is_quant_method_supported("aqlm"):
    TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
        "quantization": "aqlm"
    }))

# TODO: figure out why this fails.
# The leading ``False`` deliberately disables this entry: the condition
# short-circuits, so the GGUF model is never appended until the TODO above
# is resolved.
if False and is_quant_method_supported("gguf"):  # noqa: SIM223
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
        "quantization": "gguf"
    }))

# GPTQ-family and Marlin checkpoints. Each (method, checkpoint) pair is
# checked in order and its model is appended only when
# ``is_quant_method_supported`` accepts the method.
TEST_MODELS.extend(
    (checkpoint, {"quantization": quant_method})
    for quant_method, checkpoint in (
        ("gptq", "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"),
        ("gptq_marlin", "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"),
        ("gptq_marlin_24", "alexm-nm/tinyllama-24-marlin24-4bit-g128"),
        ("marlin", "robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin"),
    )
    if is_quant_method_supported(quant_method)
)

# The AWQ checkpoint is skipped on ROCm platforms and otherwise registered
# only when the "awq" method is supported.
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
    TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
        "quantization": "AWQ"
    }))


def check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1) -> None:
    """Build an ``LLM`` for ``model`` and run a short greedy generation.

    The function sets ``VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE=1`` in
    ``os.environ`` before constructing the engine, generates completions for
    four fixed prompts with ``temperature=0``, and prints each prompt with
    its generated text. Nothing is asserted on the generated text; any
    failure surfaces as an exception raised by the calls made here.

    Args:
        model: Model identifier, passed to ``LLM`` as ``model=``.
        model_kwargs: Mapping of extra keyword arguments expanded into the
            ``LLM`` constructor call.
        optimization_level: Value passed to ``LLM`` as
            ``compilation_config=``.
        tp_size: Value passed to ``LLM`` as ``tensor_parallel_size=``.

    Note:
        The environment variable is not restored afterwards, so it stays set
        for the remainder of the process.
    """
    # make sure these models can be captured in full graph mode
    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

    print(f"MODEL={model}")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # temperature=0 requests greedy sampling.
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model=model,
              enforce_eager=True,
              tensor_parallel_size=tp_size,
              disable_custom_all_reduce=True,
              compilation_config=optimization_level,
              **model_kwargs)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")