test_aot_compile.py 7.32 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
import functools
import multiprocessing
6
7
8
9
10
11
import tempfile
from contextlib import contextmanager

import pytest
import torch

12
import vllm.model_executor.layers.activation
13
14
15
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
    CompilationConfig,
16
    CompilationMode,
17
18
19
    VllmConfig,
    set_current_vllm_config,
)
20
from vllm.envs import disable_envs_cache
21
from vllm.forward_context import set_forward_context
22
from vllm.utils.torch_utils import is_torch_equal_or_newer
23

24
25
from ..utils import create_new_process_for_each_test

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

def reference_fn(x: torch.Tensor) -> torch.Tensor:
    """Eager reference computation shared by the tests in this module.

    Asserts that the leading dimension of ``x`` is at most 42 and even, then
    adds ``x.shape[0]`` to ``x`` 3000 times and returns the result.
    """
    # NOTE(review): these two shape assertions look like the source of the
    # guards checked in test_shape_env (" - s77 <= 42" and
    # " - Eq(Mod(s77, 2), 0)"); keep their form and order unless that test
    # is updated too.
    assert x.shape[0] <= 42
    assert x.shape[0] % 2 == 0
    # Repeated additions that depend on the leading dimension of x.
    for _ in range(3000):
        x = x + x.shape[0]
    return x


@support_torch_compile
class CompiledMod(torch.nn.Module):
    # Minimal module wrapped with ``support_torch_compile``; its forward pass
    # simply delegates to ``reference_fn``.

    def __init__(self, **kwargs):
        # Keyword arguments (the tests pass ``vllm_config=...``) are accepted
        # but not used here.
        # NOTE(review): presumably they are consumed by the
        # ``support_torch_compile`` wrapper — confirm against the decorator.
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Return ``reference_fn(x)``.
        return reference_fn(x)


def make_vllm_config() -> VllmConfig:
    """Build a ``VllmConfig`` whose compilation mode is ``VLLM_COMPILE``.

    No other option is passed, so everything else keeps its default value.

    Returns:
        A freshly constructed ``VllmConfig``.
    """
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    return VllmConfig(compilation_config=compilation_config)


@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
    """Activate ``vllm_config`` for the duration of the ``with`` body.

    Enters ``set_forward_context({}, vllm_config)`` first and then
    ``set_current_vllm_config(vllm_config)``; both are exited in reverse
    order once the body finishes.
    """
    with set_forward_context({}, vllm_config):
        with set_current_vllm_config(vllm_config):
            yield


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
    """Contrast the ``fail_on_recompile`` stance with AOT compile off and on.

    Phase 1 sets ``VLLM_USE_AOT_COMPILE=0`` and expects calling the compiled
    module under ``fail_on_recompile`` to raise ``RuntimeError`` matching
    "Detected recompile". Phase 2 sets ``VLLM_USE_AOT_COMPILE=1``, resets
    dynamo, and expects the same call to succeed and match the eager
    ``reference_fn`` output.
    """
    with monkeypatch.context() as m:
        vllm_config = make_vllm_config()
        args = (torch.randn(10, 10),)
        # Eager result used as ground truth for the AOT-compiled run.
        expected = reference_fn(*args)
        with use_vllm_config(vllm_config):
            # Phase 1: AOT compile disabled -> the call must be rejected.
            m.setenv("VLLM_USE_AOT_COMPILE", "0")
            with (
                pytest.raises(RuntimeError, match="Detected recompile"),
                torch.compiler.set_stance("fail_on_recompile"),
            ):
                CompiledMod(vllm_config=vllm_config)(*args)

            # NOTE(review): presumably needed so the env var change below is
            # observed rather than a cached value — confirm in vllm.envs.
            disable_envs_cache()

            # Phase 2: AOT compile enabled, starting from a clean dynamo state.
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            torch._dynamo.reset()
            with torch.compiler.set_stance("fail_on_recompile"):
                actual = CompiledMod(vllm_config=vllm_config)(*args)
            assert torch.allclose(actual, expected)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
    """Forcing an AOT load against an empty cache root must fail.

    With ``VLLM_USE_AOT_COMPILE=1``, ``VLLM_FORCE_AOT_LOAD=1`` and
    ``VLLM_CACHE_ROOT`` pointing at a fresh temporary directory, calling the
    compiled module is expected to raise ``FileNotFoundError``.
    """
    with tempfile.TemporaryDirectory() as cache_root:
        with monkeypatch.context() as m:
            inputs = (torch.randn(10, 10),)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            m.setenv("VLLM_CACHE_ROOT", cache_root)
            config = make_vllm_config()
            with use_vllm_config(config):
                with pytest.raises(FileNotFoundError):
                    module = CompiledMod(vllm_config=config)
                    module(*inputs)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
    """Run once with AOT compile, then again with a forced AOT load.

    Both runs share the same temporary ``VLLM_CACHE_ROOT``. The second run
    sets ``VLLM_FORCE_AOT_LOAD=1`` and its output must match the first run's
    output.
    """
    with monkeypatch.context() as m:
        args = (torch.randn(10, 10),)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # First run: AOT compile enabled, cache rooted in the temp dir.
            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                expected = CompiledMod(vllm_config=vllm_config)(*args)

            # NOTE(review): presumably needed so the env var change below is
            # observed rather than a cached value — confirm in vllm.envs.
            disable_envs_cache()

            # Second run: a fresh config and module, with AOT load forced.
            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                ret = CompiledMod(vllm_config=vllm_config)(*args)
            assert torch.allclose(ret, expected)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
    """
    Test that the shape environment is correctly serialized and preserved
    when loading from cache.

    The formatted guards of the compiled artifact are checked twice: after a
    first run with AOT compile enabled, and after a second run with
    ``VLLM_FORCE_AOT_LOAD=1`` against the same cache root.
    """
    # Guard text expected from both the first run and the forced-load run.
    expected_guards = " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

    with monkeypatch.context() as m:
        args = (torch.randn(10, 10),)

        def compile_and_format_guards() -> str:
            # Build a fresh config and module, run it once, and return the
            # formatted shape-env guards of the resulting AOT artifact.
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                compiled_mod = CompiledMod(vllm_config=vllm_config)
                compiled_mod(*args)
                artifacts = compiled_mod.aot_compiled_fn._artifacts
                return artifacts.compiled_fn.shape_env.format_guards()

        with tempfile.TemporaryDirectory() as tmpdirname:
            # First run: AOT compile enabled, cache rooted in the temp dir.
            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            assert compile_and_format_guards() == expected_guards

            # NOTE(review): presumably needed so the env var change below is
            # observed rather than a cached value — confirm in vllm.envs.
            disable_envs_cache()

            # Second run: AOT load forced; guards must be unchanged.
            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            assert compile_and_format_guards() == expected_guards


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
    """
    Test that compiling gpt2 twice results in a cache hit, and count
    torch dynamic symbol creations to ensure ``make_symbol`` is not
    called on the cache hit.
    """

    import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
    from torch.utils._sympy.symbol import make_symbol

    from vllm import LLM

    # Shared integer counter, incremented on every make_symbol call.
    create_symbol_counter = multiprocessing.Value("i", 0)
    original_make_symbol = make_symbol

    @functools.wraps(original_make_symbol)
    def counting_make_symbol(prefix, idx, **kwargs):
        with create_symbol_counter.get_lock():
            create_symbol_counter.value += 1
        return original_make_symbol(prefix, idx, **kwargs)

    def build_llm():
        # Both compilations use identical settings; a new CompilationConfig
        # is created for each model instance.
        return LLM(
            model="gpt2",
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
            ),
            max_model_len=256,
        )

    # Patch the name looked up through the symbolic_shapes module; it is
    # restored in the ``finally`` clause below.
    symbolic_shapes_module.make_symbol = counting_make_symbol
    try:
        with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            # First compilation - initialize model and generate
            llm_model = build_llm()

            llm_model.generate("Hello, my name is")
            assert create_symbol_counter.value == 2
            # Reset under the same lock used by the increment.
            with create_symbol_counter.get_lock():
                create_symbol_counter.value = 0

            # Clean up first model
            del llm_model
            # NOTE(review): presumably needed so the env var change below is
            # observed rather than a cached value — confirm in vllm.envs.
            disable_envs_cache()
            # NOTE(review): empties the activation registry's backing dict,
            # presumably so the second LLM can repopulate it — confirm.
            vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()

            # Second compilation - should hit cache
            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            llm_model = build_llm()
            llm_model.generate("Hello, my name is")

            # No new symbols may be created when loading from the cache.
            assert create_symbol_counter.value == 0

    finally:
        # Restore original method
        symbolic_shapes_module.make_symbol = original_make_symbol