utils.py 3.31 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
9
from vllm.platforms import current_platform
10
11
import os
from ..utils import models_path_prefix
12
13

def _model_path(name):
    """Return the path of test model *name* under ``models_path_prefix``."""
    return os.path.join(models_path_prefix, name)


# (model path, extra ``LLM(...)`` keyword arguments) pairs exercised by the
# compile tests. Quantized checkpoints are appended below only when
# ``is_quant_method_supported`` reports support for their method.
TEST_MODELS = [
    (_model_path("facebook/opt-125m"), {}),
    (_model_path("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"),
     {
         "dtype": torch.float16,
         "quantization": "compressed-tensors",
     }),
    # NOTE(review): disabled in this fork with no recorded reason — presumably
    # FP8 is not supported on the target platform; confirm before re-enabling.
    # (_model_path("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic"), {
    #     "dtype": torch.float16,
    #     "quantization": "fp8"
    # }),
    (_model_path("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"), {
        "quantization": "compressed-tensors",
    }),
    (_model_path("meta-llama/Llama-3.2-1B-Instruct"), {}),
]

if is_quant_method_supported("aqlm"):
    TEST_MODELS.append((
        _model_path("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"),
        {"quantization": "aqlm"},
    ))

# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"):  # noqa: SIM223
    TEST_MODELS.append((
        _model_path("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"),
        {"quantization": "gguf"},
    ))

if is_quant_method_supported("gptq"):
    TEST_MODELS.append((
        _model_path("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"),
        {"quantization": "gptq"},
    ))

# NOTE(review): the Marlin-family variants below are disabled in this fork
# with no recorded reason; confirm platform support before re-enabling.
# if is_quant_method_supported("gptq_marlin"):
#     TEST_MODELS.append((
#         _model_path("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"),
#         {"quantization": "gptq_marlin"},
#     ))
#
# if is_quant_method_supported("gptq_marlin_24"):
#     TEST_MODELS.append((
#         _model_path("alexm-nm/tinyllama-24-marlin24-4bit-g128"),
#         {"quantization": "gptq_marlin_24"},
#     ))
#
# if is_quant_method_supported("marlin"):
#     TEST_MODELS.append((
#         _model_path("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin"),
#         {"quantization": "marlin"},
#     ))

# AWQ is skipped on ROCm even when the method itself reports support.
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
    TEST_MODELS.append((
        _model_path("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"),
        {"quantization": "AWQ"},
    ))


67
68
69
70
def check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1):
    """Smoke-test *model* with Dynamo full-graph capture enabled.

    Sets ``VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE=1``, builds an ``LLM`` with
    ``compilation_config=optimization_level`` and greedily generates
    (``temperature=0``) for a few fixed prompts, printing each completion.
    No assertion is made on the generated text: the check passes as long as
    construction and generation do not raise.

    The environment variable is restored to its previous state (or removed
    if it was unset) once generation finishes, so it does not leak into
    other tests running in the same process.

    Args:
        model: Model name or path passed as ``LLM(model=...)``.
        model_kwargs: Extra keyword arguments forwarded to ``LLM``
            (e.g. ``dtype``, ``quantization``).
        optimization_level: Forwarded to ``LLM`` as ``compilation_config``.
        tp_size: Tensor-parallel size passed to ``LLM``.
    """
    env_key = "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
    previous_value = os.environ.get(env_key)
    # Make sure these models can be captured in full graph mode.
    os.environ[env_key] = "1"
    try:
        print(f"MODEL={model}")

        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        sampling_params = SamplingParams(temperature=0)
        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tp_size,
                  disable_custom_all_reduce=True,
                  compilation_config=optimization_level,
                  **model_kwargs)

        outputs = llm.generate(prompts, sampling_params)
    finally:
        # Undo the process-wide env change so later tests see the original
        # configuration.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")