# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

6
import logging
7
import tempfile
8
from typing import Any, Optional, Union
9

10
import pytest
11
import torch
12

13
from tests.quantization.utils import is_quant_method_supported
14
from tests.v1.attention.utils import _Backend
15
from vllm import LLM, SamplingParams
16
17
18
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         PassConfig)
19
from vllm.platforms import current_platform
20
from vllm.utils import is_torch_equal_or_newer
21

22
from ..utils import create_new_process_for_each_test
23
24


25
def models_list(
    *,
    all: bool = True,
    keywords: Optional[list[str]] = None,
) -> list[tuple[str, dict[str, Any]]]:
    """Build the ``(model name, LLM kwargs)`` pairs used to parametrize tests.

    Args:
        all: When true, the base list is extended with quantized checkpoints
            whose quantization method ``is_quant_method_supported`` reports
            as available. The name shadows the builtin ``all`` but is kept
            because callers pass it by keyword.
        keywords: Optional substrings. When given, only models whose name
            contains at least one of them are returned; when ``None`` the
            list is returned unfiltered.

    Returns:
        A list of ``(model_name, model_kwargs)`` tuples.
    """
    TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
        ("facebook/opt-125m", {}),
        ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
            "dtype": torch.float16,
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
            "dtype": torch.float16,
        }),
        ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
        ("meta-llama/Llama-3.2-1B-Instruct", {}),
    ]

    if all:

        # TODO: figure out why this fails.
        # Deliberately disabled via `False and ...` until the failure is
        # understood; the condition is kept so re-enabling is a one-word edit.
        if False and is_quant_method_supported("gguf"):  # noqa: SIM223
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
                "quantization": "gguf"
            }))

        if is_quant_method_supported("gptq"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
                "quantization": "gptq"
            }))

        if is_quant_method_supported("gptq_marlin"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
                "quantization": "gptq_marlin"
            }))

        if is_quant_method_supported("gptq_marlin_24"):
            TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
                "quantization": "gptq_marlin_24"
            }))

        # The AWQ checkpoint is not added on ROCm.
        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
                "quantization": "AWQ"
            }))

    if keywords is None:
        return TEST_MODELS

    # Keep only the models whose name contains at least one keyword.
    return [
        model_info for model_info in TEST_MODELS
        if any(keyword in model_info[0] for keyword in keywords)
    ]
72
73


74
75
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
    monkeypatch: pytest.MonkeyPatch,
    model_info: tuple[str, dict[str, Any]],
    optimization_level: int,
):
    """Run every model from ``models_list(all=True)`` at each compile level.

    Args:
        monkeypatch: pytest fixture; its context is held open while the
            model runs.
        model_info: ``(model_name, model_kwargs)`` pair to load.
        optimization_level: Compilation level passed to ``run_model`` as the
            compilation config.
    """
    model, model_kwargs = model_info

    with monkeypatch.context():
        print(f"MODEL={model}")

        run_model(optimization_level, model, model_kwargs)


# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config, model_info",
    [
        # additional compile sizes, only some of the models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           compile_sizes=[1, 2]), model)
        for model in models_list(all=False)
    ] + [
        # RMSNorm + quant fusion, only 8-bit quant models
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           custom_ops=["+rms_norm"],
                           pass_config=PassConfig(enable_fusion=True,
                                                  enable_noop=True)), model)
        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ] + [
        # Test depyf integration works
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           debug_dump_path=tempfile.gettempdir()),
         ("facebook/opt-125m", {})),
    ] + [
        # graph inductor partition
        (
            CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                # inductor graph partition uses
                # torch._C.Tag.cudagraph_unsafe to specify splitting ops
                use_inductor_graph_partition=True,
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                compile_sizes=[1, 2]),
            model) for model in models_list(all=False)
        if is_torch_equal_or_newer("2.9.0.dev")
    ])
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
    compilation_config: CompilationConfig,
    model_info: tuple[str, dict[str, Any]],
):
    """Run a model under one of the custom compilation configs listed above.

    Args:
        compilation_config: The ``CompilationConfig`` to pass to ``run_model``.
        model_info: ``(model_name, model_kwargs)`` pair to load.
    """
    # The graph-partition cases are already filtered out of the parameter
    # list on older torch; this guard additionally skips at run time.
    if (compilation_config.use_inductor_graph_partition
            and not is_torch_equal_or_newer("2.9.0.dev")):
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")

    model, model_kwargs = model_info
    print(f"MODEL={model}")
    run_model(compilation_config, model, model_kwargs)


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
)
def test_fp8_kv_scale_compile(optimization_level: int):
    """Run Qwen2-0.5B with fp8 quantization and computed KV-cache scales.

    The model is run once per parametrized compilation level, with an
    ``fp8_e4m3`` KV cache and ``calculate_kv_scales`` switched on.
    """
    fp8_kwargs = dict(
        quantization="fp8",
        kv_cache_dtype="fp8_e4m3",
        calculate_kv_scales=True,
        max_model_len=512,
    )
    run_model(optimization_level, "Qwen/Qwen2-0.5B", fp8_kwargs)


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
    """Check the attention+quant fusion log under inductor graph partition.

    Skipped unless torch is at least ``2.9.0.dev``. The model is run with the
    FLASHINFER attention backend forced while DEBUG logs are captured, and
    the captured text is then inspected for the fusion message.
    """
    torch_is_new_enough = is_torch_equal_or_newer("2.9.0.dev")
    if not torch_is_new_enough:
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")

    fusion_passes = PassConfig(enable_attn_fusion=True, enable_noop=True)
    fusion_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
        custom_ops=["+quant_fp8"],
        pass_config=fusion_passes,
    )
    engine_kwargs = dict(kv_cache_dtype="fp8", max_model_len=1024)

    with caplog_vllm.at_level(logging.DEBUG):
        with global_force_attn_backend_context_manager(_Backend.FLASHINFER):
            run_model(
                fusion_config,
                "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
                engine_kwargs,
            )

    fused_message = "Fused quantization onto 48 attention nodes"
    try:
        assert fused_message in caplog_vllm.text, caplog_vllm.text
    except AssertionError:
        # Note: the message is only logged when compilation actually runs
        # the custom pass. PyTorch has several cache layers, so a cached
        # graph can bypass the pass entirely. In that case no fusion
        # message at all should have been logged, which is what this
        # fallback asserts.
        assert "Fused quantization" not in caplog_vllm.text


192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def run_model(compile_config: Union[int, CompilationConfig], model: str,
              model_kwargs: dict[str, Any]):
    """Load ``model`` and generate completions for a few fixed prompts.

    Args:
        compile_config: Compilation level or ``CompilationConfig`` forwarded
            to ``LLM`` as ``compilation_config``.
        model: Model name passed to ``LLM``.
        model_kwargs: Extra keyword arguments forwarded to ``LLM``.

    Each prompt and its first generated completion are printed; nothing is
    returned.
    """
    test_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    greedy_params = SamplingParams(temperature=0)

    engine = LLM(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        compilation_config=compile_config,
        **model_kwargs,
    )

    # Print the outputs.
    for request_output in engine.generate(test_prompts, greedy_params):
        print(f"Prompt: {request_output.prompt!r}, "
              f"Generated text: {request_output.outputs[0].text!r}")