# test_full_graph.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import tempfile
5
from pathlib import Path
6
from typing import Any
7

8
import pytest
9
import torch
10

11
12
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
13
from vllm.attention.backends.registry import AttentionBackendEnum
14
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
15
from vllm.platforms import current_platform
16
from vllm.utils.torch_utils import is_torch_equal_or_newer
17

18
from ...utils import create_new_process_for_each_test
19
20


21
def models_list(*, all: bool = True, keywords: list[str] | None = None):
22
23
    TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
        ("facebook/opt-125m", {}),
24
25
        (
            "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
26
            {"dtype": torch.float16},
27
        ),
28
29
30
        ("meta-llama/Llama-3.2-1B-Instruct", {}),
    ]

31
    if all:
32
33
34
35
36
37
38
39
40
41
        TEST_MODELS.extend(
            [
                ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
                (
                    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
                    {"dtype": torch.float16},
                ),
            ]
        )

42
43
        # TODO: figure out why this fails.
        if False and is_quant_method_supported("gguf"):  # noqa: SIM223
44
45
46
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
            )
47
48

        if is_quant_method_supported("gptq"):
49
50
51
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
            )
52
53

        if is_quant_method_supported("gptq_marlin"):
54
55
56
57
58
59
            TEST_MODELS.append(
                (
                    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
                    {"quantization": "gptq_marlin"},
                )
            )
60

61
        if is_quant_method_supported("gptq_marlin_24"):
62
63
64
            TEST_MODELS.append(
                (
                    "alexm-nm/tinyllama-24-marlin24-4bit-g128",
65
66
67
68
                    {
                        "quantization": "gptq_marlin_24",
                        "allow_deprecated_quantization": True,
                    },
69
70
                )
            )
71

72
        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
73
74
75
            TEST_MODELS.append(
                ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
            )
76
77
78
79
80
81
82

    if keywords is None:
        return TEST_MODELS

    # filter by keywords
    pred = lambda model: any(keyword in model[0] for keyword in keywords)
    return list(filter(pred, TEST_MODELS))
83
84


85
@pytest.mark.parametrize(
    "compilation_mode",
    [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    model_kwargs: dict[str, Any],
    compilation_mode: int,
):
    """Generate text with every model from ``models_list(all=True)`` under
    each of the parametrized compilation modes, via ``run_model``."""
    # int8 (w8a8 / w8w8) support was removed on Blackwell, so skip those
    # models only on devices where has_device_capability((10, 0)) holds.
    # The grouping matters: ``and`` binds tighter than ``or``, so the
    # un-parenthesized form skipped every w8a8 model on every platform.
    is_int8_model = "w8a8" in model or "w8w8" in model
    if is_int8_model and current_platform.has_device_capability((10, 0)):
        pytest.skip("int8 support removed on Blackwell")

    # NOTE(review): this context applies no patches at the moment; it only
    # scopes any future monkeypatching to the model run.
    with monkeypatch.context():
        print(f"MODEL={model}")

        run_model(compilation_mode, model, **model_kwargs)
109
110
111
112


# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config, model, model_kwargs",
    [
        # additional compile sizes, only some of the models
        (
            CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
            *model_info,
        )
        for model_info in models_list(all=False)
    ]
    + [
        # RMSNorm + quant fusion, only 8-bit quant models
        (
            CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                custom_ops=["+rms_norm"],
                pass_config=PassConfig(
                    fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
                ),
            ),
            *model_info,
        )
        for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ]
    + [
        # Test depyf integration works
        (
            CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                debug_dump_path=Path(tempfile.gettempdir()),
            ),
            "facebook/opt-125m",
            {},
        ),
    ]
    + [
        # graph inductor partition
        (
            CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                # inductor graph partition uses
                # torch._C.Tag.cudagraph_unsafe to specify splitting ops
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                compile_sizes=[1, 2],
            ),
            *model_info,
        )
        for model_info in models_list(all=False)
        if is_torch_equal_or_newer("2.9.0.dev")
    ]
    + [
        # Test get_raw_stream patch with compile_sizes
        # This tests that TorchInductor autotune works correctly with get_raw_stream
        # patch in torch 2.9 and without patch in torch 2.10+
        (
            CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_sizes=[1, 2],  # Triggers autotune which uses get_raw_stream
                cudagraph_mode=CUDAGraphMode.NONE,
            ),
            "facebook/opt-125m",
            {},
        ),
    ],
)
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
    compilation_config: CompilationConfig,
    model: str,
    model_kwargs: dict[str, Any],
):
    """Generate text with a model under a specific ``CompilationConfig``
    scenario taken from the parametrization above, via ``run_model``."""
    # int8 (w8a8 / w8w8) support was removed on Blackwell, so skip those
    # models only on devices where has_device_capability((10, 0)) holds.
    # The grouping matters: ``and`` binds tighter than ``or``, so the
    # un-parenthesized form skipped every w8a8 model on every platform.
    is_int8_model = "w8a8" in model or "w8w8" in model
    if is_int8_model and current_platform.has_device_capability((10, 0)):
        pytest.skip("int8 support removed on Blackwell")

    # The parametrization already drops the inductor-partition cases on older
    # torch; this runtime guard is a second line of defense.
    if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
        "2.9.0.dev"
    ):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    print(f"MODEL={model}")
    run_model(compilation_config, model, **model_kwargs)
200
201


202
@pytest.mark.parametrize(
    "compilation_mode",
    [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize(
    "model, backend",
    [
        ("Qwen/Qwen2-0.5B", None),  # Standard attention model
        (
            "deepseek-ai/DeepSeek-V2-Lite",
            AttentionBackendEnum.FLASHINFER_MLA,
        ),  # MLA (Multi-head Latent Attention) model
    ],
)
def test_fp8_kv_scale_compile(
    compilation_mode: int,
    model: str,
    backend: AttentionBackendEnum | None,
):
    """Run a model with fp8 quantization, an fp8_e4m3 KV cache and
    ``calculate_kv_scales=True``, with and without compilation.

    When the case names an attention backend, it is forwarded to ``LLM`` as
    ``attention_config={"backend": <enum member name>}``.
    """
    # Only pin the attention backend for cases that specify one.
    backend_override = (
        {"attention_config": {"backend": backend.name}} if backend else {}
    )

    run_model(
        compilation_mode,
        model,
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        max_model_len=512,
        **backend_override,
    )
231
232


233
234
235
236
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Build an ``LLM`` and run a handful of fixed prompts through it.

    Args:
        compile_config: Either a ready ``CompilationConfig`` (used directly,
            and its ``cudagraph_mode`` may be filled in below) or a
            compilation mode, wrapped as ``CompilationConfig(mode=...)``.
        model: Model name handed to ``LLM``.
        **model_kwargs: Extra ``LLM`` keyword arguments. They take precedence
            over the defaults ``tensor_parallel_size=1`` and
            ``disable_custom_all_reduce=True``.

    Each prompt and its first completion are printed.
    """
    if isinstance(compile_config, CompilationConfig):
        compilation_config = compile_config
    else:
        compilation_config = CompilationConfig(mode=compile_config)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)

    # Caller-supplied kwargs win over these defaults.
    llm_kwargs = {
        "disable_custom_all_reduce": True,
        "tensor_parallel_size": 1,
        **model_kwargs,
    }

    # Unless the caller picked a cudagraph mode, run without cudagraphs.
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(model=model, compilation_config=compilation_config, **llm_kwargs)

    for request_output in llm.generate(prompts, sampling_params):
        prompt_text = request_output.prompt
        completion = request_output.outputs[0].text
        print(f"Prompt: {prompt_text!r}, Generated text: {completion!r}")