# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import logging
5
import tempfile
6
from typing import Any
7

8
import pytest
9
import torch
10

11
12
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
13
from vllm.attention.backends.registry import _Backend
14
from vllm.attention.selector import global_force_attn_backend_context_manager
15
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig
16
from vllm.platforms import current_platform
17
from vllm.utils import is_torch_equal_or_newer
18

19
from ..utils import create_new_process_for_each_test
20
21


22
def models_list(*, all: bool = True, keywords: list[str] | None = None):
23
24
    TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
        ("facebook/opt-125m", {}),
25
26
27
28
29
30
31
32
33
34
35
36
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            {
                "dtype": torch.float16,
            },
        ),
        (
            "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
            {
                "dtype": torch.float16,
            },
        ),
37
        ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
38
39
40
        ("meta-llama/Llama-3.2-1B-Instruct", {}),
    ]

41
42
43
    if all:
        # TODO: figure out why this fails.
        if False and is_quant_method_supported("gguf"):  # noqa: SIM223
44
45
46
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
            )
47
48

        if is_quant_method_supported("gptq"):
49
50
51
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
            )
52
53

        if is_quant_method_supported("gptq_marlin"):
54
55
56
57
58
59
            TEST_MODELS.append(
                (
                    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
                    {"quantization": "gptq_marlin"},
                )
            )
60

61
        if is_quant_method_supported("gptq_marlin_24"):
62
63
64
65
66
67
            TEST_MODELS.append(
                (
                    "alexm-nm/tinyllama-24-marlin24-4bit-g128",
                    {"quantization": "gptq_marlin_24"},
                )
            )
68

69
        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
70
71
72
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
            )
73
74
75
76
77
78
79

    if keywords is None:
        return TEST_MODELS

    # filter by keywords
    pred = lambda model: any(keyword in model[0] for keyword in keywords)
    return list(filter(pred, TEST_MODELS))
80
81


82
83
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
    monkeypatch: pytest.MonkeyPatch,
    model_info: tuple[str, dict[str, Any]],
    optimization_level: int,
):
    """Run each model from ``models_list(all=True)`` at each compilation level.

    The test passes when ``run_model`` completes without raising; no output
    text is checked.

    Args:
        monkeypatch: pytest fixture; only its ``context()`` scope is entered.
        model_info: ``(model_name, engine_kwargs)`` pair from ``models_list``.
        optimization_level: Compilation level handed to ``run_model``.
    """
    model, model_kwargs = model_info

    # NOTE(review): nothing is patched inside this context at present.
    with monkeypatch.context():
        print(f"MODEL={model}")
        run_model(optimization_level, model, model_kwargs)


# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config, model_info",
    [
        # additional compile sizes, only some of the models
        (
            CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]),
            model,
        )
        for model in models_list(all=False)
    ]
    + [
        # RMSNorm + quant fusion, only 8-bit quant models
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                custom_ops=["+rms_norm"],
                pass_config=PassConfig(enable_fusion=True, enable_noop=True),
            ),
            model,
        )
        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ]
    + [
        # Test depyf integration works
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()
            ),
            ("facebook/opt-125m", {}),
        ),
    ]
    + [
        # graph inductor partition
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                # inductor graph partition uses
                # torch._C.Tag.cudagraph_unsafe to specify splitting ops
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                compile_sizes=[1, 2],
            ),
            model,
        )
        for model in models_list(all=False)
        # Cases are only generated on torch builds that support the feature.
        if is_torch_equal_or_newer("2.9.0.dev")
    ],
)
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
    compilation_config: CompilationConfig,
    model_info: tuple[str, dict[str, Any]],
):
    """Run a model under a hand-built ``CompilationConfig``.

    The parametrization covers four scenarios: extra compile sizes,
    RMSNorm + quant fusion, a debug dump path, and inductor graph partition.
    The test passes when ``run_model`` completes without raising.

    Args:
        compilation_config: The configuration under test.
        model_info: ``(model_name, engine_kwargs)`` pair to load.
    """
    # Runtime guard mirroring the collection-time filter on the partition cases.
    if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
        "2.9.0.dev"
    ):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    model, model_kwargs = model_info
    print(f"MODEL={model}")
    run_model(compilation_config, model, model_kwargs)


166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
)
def test_fp8_kv_scale_compile(optimization_level: int):
    """Run Qwen2-0.5B with fp8 weights and an fp8 KV cache with computed scales.

    Exercised both without compilation and with piecewise compilation; the
    test passes when ``run_model`` completes without raising.
    """
    fp8_engine_kwargs = dict(
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        max_model_len=512,
    )
    run_model(optimization_level, "Qwen/Qwen2-0.5B", fp8_engine_kwargs)


181
182
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
    """Check attention + quant fusion logging under inductor graph partition.

    Runs an FP8 Llama-4-Scout checkpoint with piecewise compilation, inductor
    graph partition and the attention-fusion pass enabled, forcing the
    FLASHINFER attention backend and capturing DEBUG logs. Skipped on torch
    builds older than 2.9.0.dev.

    Args:
        caplog_vllm: Log-capture fixture exposing ``at_level`` and ``text``.
    """
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
    compilation_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        custom_ops=["+quant_fp8"],
        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
    )
    model_kwargs = {
        "kv_cache_dtype": "fp8",
        "max_model_len": 1024,
    }
    with (
        caplog_vllm.at_level(logging.DEBUG),
        global_force_attn_backend_context_manager(_Backend.FLASHINFER),
    ):
        run_model(compilation_config, model, model_kwargs)

    captured = caplog_vllm.text
    # Note: the fusion message is only emitted when compilation actually goes
    # through the custom pass. Due to multiple layers of cache on the PyTorch
    # side, the compilation of a graph may be cached such that the custom pass
    # is bypassed; in that case no "Fused quantization" line is logged at all,
    # which is accepted. If the pass did run, it must report all 48 nodes.
    if "Fused quantization" in captured:
        assert "Fused quantization onto 48 attention nodes" in captured, captured


217
def run_model(
    compile_config: int | CompilationConfig,
    model: str,
    model_kwargs: dict[str, Any],
):
    """Load ``model`` with the given compilation config and generate text.

    Four fixed prompts are generated with ``temperature=0`` and each
    prompt/completion pair is printed. Nothing is returned or asserted; the
    callers rely on this function raising if engine construction or
    generation fails.

    Args:
        compile_config: Either a compilation level or a full
            ``CompilationConfig``; forwarded as ``compilation_config``.
        model: Model name or path passed to ``LLM``.
        model_kwargs: Extra keyword arguments forwarded to ``LLM``.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        compilation_config=compile_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")