test_full_graph.py 5.05 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
from typing import Any, Optional, Union
7

8
import pytest
9
import torch
10

11
12
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
13
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
14
from vllm.platforms import current_platform
15

16
from ..utils import create_new_process_for_each_test
17
18


19
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
    """Return the ``(model name, LLM kwargs)`` pairs used by the tests below.

    Args:
        all: When true, the base list is extended with one checkpoint per
            quantization method for which ``is_quant_method_supported``
            returns true. The name shadows the builtin, but it is part of
            the keyword-only interface callers rely on, so it is kept.
        keywords: Optional list of substrings. When given, only models whose
            name contains at least one of the substrings are returned;
            ``None`` disables filtering.

    Returns:
        A list of ``(model_name, extra_llm_kwargs)`` tuples.
    """
    TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
        ("facebook/opt-125m", {}),
        ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
            "dtype": torch.float16,
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
            "dtype": torch.float16,
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
        ("meta-llama/Llama-3.2-1B-Instruct", {}),
    ]

    if all:
        if is_quant_method_supported("aqlm"):
            TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
                "quantization": "aqlm"
            }))

        # TODO: figure out why this fails.
        # Deliberately disabled via the constant ``False`` until then.
        if False and is_quant_method_supported("gguf"):  # noqa: SIM223
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
                "quantization": "gguf"
            }))

        if is_quant_method_supported("gptq"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
                "quantization": "gptq"
            }))

        if is_quant_method_supported("gptq_marlin"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
                "quantization": "gptq_marlin"
            }))

        if is_quant_method_supported("gptq_marlin_24"):
            TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
                "quantization": "gptq_marlin_24"
            }))

        if is_quant_method_supported("marlin"):
            TEST_MODELS.append(
                ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
                    "quantization": "marlin"
                }))

        # The AWQ checkpoint is skipped on ROCm.
        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
                "quantization": "AWQ"
            }))

    if keywords is None:
        return TEST_MODELS

    # Keep only the models whose name contains at least one keyword.
    return [
        entry for entry in TEST_MODELS
        if any(keyword in entry[0] for keyword in keywords)
    ]
76
77


78
79
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
    monkeypatch: pytest.MonkeyPatch,
    model_info: tuple[str, dict[str, Any]],
    optimization_level: int,
):
    """Run every model from ``models_list(all=True)`` at each parametrized
    compilation level with ``VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE`` set to "1".

    Args:
        monkeypatch: pytest fixture used to set the environment variable
            only for the duration of the ``with`` block.
        model_info: ``(model_name, extra_llm_kwargs)`` pair.
        optimization_level: Compilation level forwarded to ``run_model``.
    """
    model, model_kwargs = model_info

    with monkeypatch.context() as m:
        # make sure these models can be captured in full graph mode
        m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
        print(f"MODEL={model}")

        run_model(optimization_level, model, model_kwargs)


# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config, model_info",
    [
        # additional compile sizes, only some of the models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           compile_sizes=[1, 2]), model)
        for model in models_list(all=False)
    ] + [
        # RMSNorm + quant fusion, only 8-bit quant models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           custom_ops=["+rms_norm"],
                           pass_config=PassConfig(enable_fusion=True,
                                                  enable_noop=True)), model)
        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ])
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
    compilation_config: CompilationConfig,
    model_info: tuple[str, dict[str, Any]],
):
    """Run a model under an explicit ``CompilationConfig``.

    The parametrization pairs each configuration with a subset of the models
    returned by ``models_list``: extra compile sizes on the base models, and
    the ``+rms_norm`` custom op with fusion/noop passes on the models matched
    by the ``FP8-dynamic`` / ``quantized.w8a8`` keywords.

    Args:
        compilation_config: Configuration forwarded to ``run_model``.
        model_info: ``(model_name, extra_llm_kwargs)`` pair.
    """
    model, model_kwargs = model_info
    print(f"MODEL={model}")
    run_model(compilation_config, model, model_kwargs)


def run_model(compile_config: Union[int, CompilationConfig], model: str,
              model_kwargs: dict[str, Any]):
    """Build an ``LLM`` for ``model`` and generate for a fixed prompt set.

    Args:
        compile_config: Passed to ``LLM`` as ``compilation_config``; either
            an integer level or a ``CompilationConfig`` instance.
        model: Model name handed to ``LLM``.
        model_kwargs: Extra keyword arguments splatted into the ``LLM``
            constructor.

    Each prompt and its first generated completion are printed; nothing is
    returned.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # temperature=0 is the only sampling option set here.
    greedy_params = SamplingParams(temperature=0)

    engine = LLM(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        compilation_config=compile_config,
        **model_kwargs,
    )

    # Echo each prompt with the text of its first completion.
    for request_output in engine.generate(prompts, greedy_params):
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")