test_simple.py 6.44 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
7

8
import pytest
9
10
11
12
13
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
14
15
from vllm.config import (
    CompilationConfig,
16
    CompilationMode,
17
18
19
20
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
21
from vllm.forward_context import BatchDescriptor, set_forward_context
22
from vllm.utils.torch_utils import is_torch_equal_or_newer
23

24
25
from ...utils import create_new_process_for_each_test

26
27
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
28
29


30
31
32
33
34
35
36
37
38
39
40
# Custom op that returns an unbacked symint during graph capture
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: torch.Tensor) -> int:
    return 3


@foo.register_fake
def _(x):
    return torch.library.get_ctx().new_dynamic_size()


41
42
@support_torch_compile
class SillyModel(nn.Module):
    """Tiny parameter-free model whose output and side effects are hand-computable.

    ``forward`` wraps two ``torch.ops.silly.attention`` calls in simple
    pointwise arithmetic. Splitting on ``silly::attention`` therefore produces
    the piecewise-graph counts the tests in this module expect.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        intermediate_unbacked: bool = False,
        **kwargs,
    ) -> None:
        """Build the model.

        Args:
            vllm_config: Not used directly here; presumably consumed by the
                ``@support_torch_compile`` wrapper (TODO confirm).
            prefix: Unused by this model.
            intermediate_unbacked: If True, ``forward`` inserts a multiply-by-one
                whose intermediate shape depends on the unbacked symint
                returned by ``mylib::foo``.
            **kwargs: Ignored.
        """
        super().__init__()
        self.intermediate_unbacked = intermediate_unbacked

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fixed pointwise ops around two ``silly.attention`` calls.

        Every call invokes ``torch.ops.silly.attention`` twice. The tests
        therefore expect the silly-attention global counter to grow by 2 per
        forward.

        NOTE(review): this was previously documented as ``x = 3 * x + 19``.
        If ``silly.attention(q, k, v, out)`` writes ``q + k + v`` into ``out``,
        the op sequence below actually computes::

            x = 9 * x + 19
            global_counter += 2

        Confirm against ``silly_attention.py``. For an all-zeros input both
        formulas give 19, which is the value the tests assert.
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2

        if self.intermediate_unbacked:
            # Test for unbacked symints: the following is a fancy way to multiply
            # by 1. At runtime ``foo`` returns 3, so each row of ``new_ones``
            # sums to 3 and dividing by 3 gives 1. During tracing ``u0`` is an
            # unbacked symint, so the intermediate (rows, u0) tensor has a
            # data-dependent shape.
            u0 = foo(x)
            ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
            x = x * ones

        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


81
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def _run_simple_model(
    splitting_ops: list[str],
    use_inductor_graph_partition: bool,
    backend: str,
    expected_num_piecewise_graphs_seen: int,
    expected_num_piecewise_capturable_graphs_seen: int,
    expected_num_backend_compilations: int,
    expected_num_cudagraph_captured: int,
    *,
    intermediate_unbacked: bool = False,
):
    """Compile ``SillyModel``, capture piecewise CUDA graphs, and check the results.

    The model is compiled with ``CompilationMode.VLLM_COMPILE`` and
    ``cudagraph_capture_sizes=[1, 2]``. It is warmed up under a background
    forward context, then run under PIECEWISE cudagraph dispatch for batch
    sizes 2 and 1. Finally it runs once more on a size-2 zeros input.

    Checks made:
      * the ``compilation_counter`` totals match the ``expected_*`` arguments;
      * the final forward bumps the silly-attention global counter by exactly 2;
      * the final forward outputs ``[19.0, 19.0]``.

    ``capture_dynamic_output_shape_ops`` is enabled, presumably so the
    ``intermediate_unbacked`` path (``mylib::foo``) can be traced without a
    graph break. TODO confirm.
    """
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            backend=backend,
            splitting_ops=splitting_ops,
            use_inductor_graph_partition=use_inductor_graph_partition,
            cudagraph_copy_inputs=True,
            cudagraph_capture_sizes=[1, 2],
        )
    )
    with set_current_vllm_config(vllm_config):
        model = SillyModel(
            vllm_config=vllm_config,
            prefix="",
            intermediate_unbacked=intermediate_unbacked,
        )

    inputs = torch.randn(100).cuda()

    def piecewise_context(num_tokens: int):
        # Capturing/replaying must happen under the PIECEWISE cudagraph
        # dispatch context for the given batch size.
        return set_forward_context(
            None,
            vllm_config=vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
        )

    with (
        compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
            num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
            num_backend_compilations=expected_num_backend_compilations,
            num_cudagraph_captured=expected_num_cudagraph_captured,
        ),
        set_forward_context(None, vllm_config=vllm_config),
    ):  # background context
        # Warm up with the background context.
        model(inputs)

        # One run per entry of cudagraph_capture_sizes (size 2, then size 1).
        for num_tokens in (2, 1):
            with piecewise_context(num_tokens):
                model(torch.randn(num_tokens).cuda())

        # Deterministic run at size 2 (presumably a replay of the graphs
        # captured above) whose output and side effects are checked exactly.
        zeros_input = torch.zeros(2).cuda()
        reset_global_counter()
        with piecewise_context(2):
            output = model(zeros_input)

        # SillyModel.forward calls silly.attention twice.
        assert get_global_counter() == 2
        assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
158
159


160
@pytest.mark.parametrize("backend", ["inductor", "eager"])
@pytest.mark.parametrize("intermediate_unbacked", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(backend, intermediate_unbacked):
    """FX-level piecewise compilation splits SillyModel at each silly::attention."""
    # silly::attention calls in SillyModel.forward
    num_layers = 2
    # len(cudagraph_capture_sizes) used by _run_simple_model
    num_cudagraph_sizes = 2
    # Only the non-attention pieces are cudagraph-capturable: 1 + num_layers.
    num_capturable = 1 + num_layers

    _run_simple_model(
        splitting_ops=["silly::attention"],
        use_inductor_graph_partition=False,
        backend=backend,
        # 2 * num_layers + 1 == 5
        expected_num_piecewise_graphs_seen=2 * num_layers + 1,
        # 1 + num_layers == 3
        expected_num_piecewise_capturable_graphs_seen=num_capturable,
        # one backend compilation per capturable piece == 3
        expected_num_backend_compilations=num_capturable,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen == 6
        expected_num_cudagraph_captured=num_cudagraph_sizes * num_capturable,
        intermediate_unbacked=intermediate_unbacked,
    )


@torch.inference_mode()
def test_simple_inductor_graph_partition(monkeypatch):
    """Inductor graph partition: one FX graph, split into cudagraph pieces by Inductor."""
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    # Disable the compile cache so that we run separately for different
    # splitting_ops and get the expected number of cudagraphs captured.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    _run_simple_model(
        splitting_ops=["silly::attention"],
        use_inductor_graph_partition=True,
        backend="inductor",
        # Since not splitting at fx graph level
        expected_num_piecewise_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_piecewise_capturable_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_backend_compilations=1,
        # Inductor graph partition still captures 6 graphs, same as fx graph partition
        expected_num_cudagraph_captured=6,
    )