test_full_graph.py 4.38 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from __future__ import annotations

5
from typing import Any, Union
6

7
import pytest
8
import torch
9

10
11
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
12
from vllm.config import CompilationConfig, CompilationLevel
13
from vllm.platforms import current_platform
14

15
from ..utils import create_new_process_for_each_test
16
17


18
def models_list(all: bool) -> list[tuple[str, dict[str, Any]]]:
    """Return ``(model name, LLM keyword arguments)`` pairs for the tests.

    Args:
        all: When false, only the base list of models is returned. When
            true, one extra checkpoint is appended for each quantization
            method that ``is_quant_method_supported`` reports as supported.

    Returns:
        A freshly built list; each entry pairs a model identifier with the
        extra keyword arguments passed along when the model is loaded.
    """
    TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
        ("facebook/opt-125m", {}),
        ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
            "dtype": torch.float16,
            "quantization": "compressed-tensors"
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
            "dtype": torch.float16,
            "quantization": "compressed-tensors"
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {
            "quantization": "compressed-tensors"
        }),
        ("meta-llama/Llama-3.2-1B-Instruct", {}),
    ]

    # The reduced set skips every platform/quantization support probe below.
    if not all:
        return TEST_MODELS

    if is_quant_method_supported("aqlm"):
        TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
            "quantization": "aqlm"
        }))

    # TODO: figure out why this fails.
    # The leading ``False`` deliberately keeps the GGUF model disabled.
    if False and is_quant_method_supported("gguf"):  # noqa: SIM223
        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
            "quantization": "gguf"
        }))

    if is_quant_method_supported("gptq"):
        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
            "quantization": "gptq"
        }))

    if is_quant_method_supported("gptq_marlin"):
        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
            "quantization": "gptq_marlin"
        }))

    if is_quant_method_supported("gptq_marlin_24"):
        TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
            "quantization": "gptq_marlin_24"
        }))

    if is_quant_method_supported("marlin"):
        TEST_MODELS.append(
            ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
                "quantization": "marlin"
            }))

    # The AWQ model is only added off ROCm.
    if not current_platform.is_rocm() and is_quant_method_supported("awq"):
        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
            "quantization": "AWQ"
        }))

    return TEST_MODELS


78
79
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
    monkeypatch: pytest.MonkeyPatch,
    model_info: tuple[str, dict[str, Any]],
    optimization_level: int,
):
    """Run every listed model at each parametrized compilation level.

    The model is loaded and run through ``run_model`` while the
    ``VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE`` environment variable is set to
    ``"1"``. There is no explicit assertion: the test passes when
    ``run_model`` completes without raising.

    Args:
        monkeypatch: pytest fixture used to set the environment variable
            for the duration of the run only.
        model_info: ``(model name, LLM keyword arguments)`` pair taken from
            ``models_list(all=True)``.
        optimization_level: Compilation level forwarded to ``run_model``.
    """
    model, model_kwargs = model_info

    with monkeypatch.context() as m:
        # make sure these models can be captured in full graph mode
        m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
        print(f"MODEL={model}")

        # Run inside the context so the environment variable is still set.
        run_model(optimization_level, model, model_kwargs)


# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config",
    [
        # additional compile sizes
        CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            compile_sizes=[1, 2],
        ),
    ],
)
# only test some of the models
@pytest.mark.parametrize("model_info", models_list(all=False))
@create_new_process_for_each_test()
def test_custom_compile_config(
    model_info: tuple[str, dict[str, Any]],
    compilation_config: CompilationConfig,
):
    """Run the reduced model list with an explicit ``CompilationConfig``.

    Args:
        model_info: ``(model name, LLM keyword arguments)`` pair taken from
            ``models_list(all=False)``.
        compilation_config: Full compilation configuration forwarded to
            ``run_model`` in place of a bare compilation level.
    """
    model_name, extra_kwargs = model_info
    print(f"MODEL={model_name}")
    run_model(compilation_config, model_name, extra_kwargs)


def run_model(compile_config: Union[int, CompilationConfig], model: str,
              model_kwargs: dict[str, Any]):
    """Load ``model`` with the given compilation setting and generate text.

    Four fixed prompts are run with ``SamplingParams(temperature=0)`` and
    each prompt/completion pair is printed. Nothing is returned or asserted;
    callers rely on this function raising if loading or generation fails.

    Args:
        compile_config: Either a compilation level or a full
            ``CompilationConfig``; passed to ``LLM`` as
            ``compilation_config``.
        model: Model identifier passed to ``LLM``.
        model_kwargs: Additional keyword arguments forwarded to ``LLM``.
    """
    test_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    params = SamplingParams(temperature=0)
    engine = LLM(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        compilation_config=compile_config,
        **model_kwargs,
    )

    # Print the outputs.
    for result in engine.generate(test_prompts, params):
        prompt = result.prompt
        generated_text = result.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")