# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import pytest
8
9
import torch

10
11
12
13
14
15
16
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    VllmConfig,
    set_current_vllm_config,
)
17
18
19


class MyMod(torch.nn.Module):
20
    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
21
22
23
24
        if x.size()[0] >= 4:
            return x * 2
        else:
            return x * 100


class MyWrapper(TorchCompileWithNoGuardsWrapper):
    """Minimal ``TorchCompileWithNoGuardsWrapper`` around an arbitrary module.

    All real work is delegated to the wrapped ``model``; this class only
    supplies the ``forward`` entry point that the base wrapper compiles.
    """

    def __init__(self, model):
        # NOTE(review): the model is attached *before* the base initializer
        # runs; presumably the base class needs it in place — keep this order.
        self.model = model
        super().__init__()

    def forward(self, x: torch.Tensor):  # type: ignore[override]
        # This is the function to be compiled: it just calls the wrapped model.
        inner = self.model
        return inner(x)

def _assert_tensors_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
    """Assert ``actual`` is allclose to ``expected``, reporting both on failure."""
    assert torch.allclose(actual, expected), f"Expected {expected}, got {actual}"


@pytest.mark.parametrize("use_bytecode_hook", [True, False])
def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
    """Test basic functionality of TorchCompileWithNoGuardsWrapper.

    Three compilation modes are exercised in turn:

    * ``DYNAMO_TRACE_ONCE``: guards are removed, so the graph traced for a
      4-element input is reused for a 3-element input even though eager
      execution of ``MyMod`` would take the other branch.
    * ``STOCK_TORCH_COMPILE``: guards are kept, so the 3-element input
      triggers another compilation and takes the ``x * 100`` branch.
    * ``mode = None``: constructing the wrapper is expected to raise.
    """
    # Set the environment variable for this test
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    # Create a proper vLLM config instead of mocking
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig()
    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
    vllm_config.compilation_config.backend = "inductor"

    # Test DYNAMO_TRACE_ONCE
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()
        wrapper = MyWrapper(mod)

        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)
        _assert_tensors_close(wrapper(x), torch.tensor([2, 4, 6, 8]))

        # Second call should use compiled code: without guards, the
        # ``size >= 4`` branch traced above is reused for a 3-element input.
        x2 = torch.tensor([1, 2, 3])
        _assert_tensors_close(wrapper(x2), torch.tensor([2, 4, 6]))

        # Without the wrapper the result would be different.
        _assert_tensors_close(mod(x2), torch.tensor([100, 200, 300]))

    # With STOCK_TORCH_COMPILE we do not remove guards.
    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        mod = MyMod()
        wrapper = MyWrapper(mod)

        # First call should trigger compilation
        x = torch.tensor([1, 2, 3, 4])
        torch._dynamo.mark_dynamic(x, 0)
        _assert_tensors_close(wrapper(x), torch.tensor([2, 4, 6, 8]))

        # Second call should trigger another compilation
        x2 = torch.tensor([1, 2, 3])
        _assert_tensors_close(wrapper(x2), torch.tensor([100, 200, 300]))

    # NO_COMPILATION level not supported: building the wrapper must fail.
    vllm_config.compilation_config.mode = None
    torch._dynamo.reset()
    with set_current_vllm_config(vllm_config):
        torch._dynamo.reset()
        mod = MyMod()

        try:
            MyWrapper(mod)
        except Exception:
            pass  # expected: no compilation mode is configured
        else:
            raise AssertionError("expected an exception to be raised")


if __name__ == "__main__":
    # Run both parametrizations directly, without the pytest runner.

    class MockMonkeypatch:
        """Minimal stand-in for pytest's ``monkeypatch`` fixture.

        Only ``setenv`` is provided; it writes straight into ``os.environ``
        and never restores the previous value.
        """

        def setenv(self, name, value):
            os.environ[name] = value

    mp = MockMonkeypatch()

    runs = (
        (False, "Testing with VLLM_USE_BYTECODE_HOOK=False"),
        (True, "Testing with VLLM_USE_BYTECODE_HOOK=True"),
    )
    for use_bytecode_hook, banner in runs:
        print(banner)
        test_torch_compile_wrapper(use_bytecode_hook, mp)

    print("All tests passed!")