# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
7

8
import pytest
9
10
11
12
13
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
14
15
from vllm.config import (
    CompilationConfig,
16
    CompilationMode,
17
18
19
20
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
21
from vllm.forward_context import BatchDescriptor, set_forward_context
22
from vllm.utils.torch_utils import is_torch_equal_or_newer
23

24
25
from ...utils import create_new_process_for_each_test

26
27
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
28
29


30
31
@support_torch_compile
class SillyModel(nn.Module):
    """Parameter-free toy model used to test piecewise compilation.

    ``forward`` is plain elementwise arithmetic interleaved with two calls to
    the custom op ``torch.ops.silly.attention``, so the tests can compute the
    expected output and the number of op invocations exactly.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        # ``vllm_config``, ``prefix`` and ``kwargs`` are accepted but not used by
        # this model; presumably this keyword-only constructor signature is what
        # ``@support_torch_compile`` expects — TODO confirm.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply +3, attention, -3, attention, +1 to ``x``.

        Overall effect (as originally documented):
        x = 3 * x + 19
        global_counter += 2

        NOTE(review): the closed form above depends on what
        ``torch.ops.silly.attention`` writes into ``out``; that op is defined in
        ``..silly_attention`` and is not visible here. ``_run_simple_model``
        only checks that an all-zeros input yields 19.0 and that the silly
        global counter reads 2 afterwards — confirm the slope.
        """
        x = x + 1
        x = x + 2
        # The custom op writes its result into a preallocated ``out`` buffer.
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


55
56
57
58
59
60
61
62
63
def _piecewise_forward_context(vllm_config, num_tokens):
    """Return a ``CUDAGraphMode.PIECEWISE`` forward context for ``num_tokens``.

    ``num_tokens`` becomes the ``BatchDescriptor.num_tokens`` of the context.
    """
    return set_forward_context(
        None,
        vllm_config=vllm_config,
        cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
        batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
    )


def _run_simple_model(
    splitting_ops,
    use_inductor_graph_partition,
    use_inductor,
    expected_num_piecewise_graphs_seen,
    expected_num_piecewise_capturable_graphs_seen,
    expected_num_backend_compilations,
    expected_num_cudagraph_captured,
):
    """Compile ``SillyModel``, run it under piecewise CUDA graphs, check counters.

    Builds a ``CompilationMode.VLLM_COMPILE`` config with capture sizes
    ``[1, 2]`` and constructs ``SillyModel`` under it. It then:

    1. runs one warm-up pass under a background forward context;
    2. runs one pass per capture size under a PIECEWISE forward context;
    3. resets the silly global counter and runs a size-2 all-zeros input.

    Args:
        splitting_ops: Passed to ``CompilationConfig.splitting_ops``.
        use_inductor_graph_partition: Passed to ``CompilationConfig``.
        use_inductor: Passed to ``CompilationConfig``.
        expected_num_piecewise_graphs_seen: Expected
            ``compilation_counter.num_piecewise_graphs_seen`` delta.
        expected_num_piecewise_capturable_graphs_seen: Expected
            ``compilation_counter.num_piecewise_capturable_graphs_seen`` delta.
        expected_num_backend_compilations: Expected
            ``compilation_counter.num_backend_compilations`` delta.
        expected_num_cudagraph_captured: Expected
            ``compilation_counter.num_cudagraph_captured`` delta.

    Raises:
        AssertionError: If the silly global counter is not 2 after the final
            pass, or the final output is not ``[19.0, 19.0]``. Counter
            mismatches are checked by ``compilation_counter.expect``.
    """
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor=use_inductor,
            splitting_ops=splitting_ops,
            use_inductor_graph_partition=use_inductor_graph_partition,
            cudagraph_copy_inputs=True,
            cudagraph_capture_sizes=[1, 2],
        )
    )
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix="")

    warmup_input = torch.randn(100).cuda()

    with (
        compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
            num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
            num_backend_compilations=expected_num_backend_compilations,
            num_cudagraph_captured=expected_num_cudagraph_captured,
        ),
        set_forward_context(None, vllm_config=vllm_config),
    ):  # background context
        # Warm up with the background context.
        model(warmup_input)

        # Capturing/replaying must happen under a cudagraph-dispatching
        # context: run once for each configured capture size (2, then 1).
        with _piecewise_forward_context(vllm_config, num_tokens=2):
            model(torch.randn(2).cuda())
        with _piecewise_forward_context(vllm_config, num_tokens=1):
            model(torch.randn(1).cuda())

        # Run the size-2 batch again on a known input and check the result
        # and the number of silly.attention invocations.
        zeros_input = torch.zeros(2).cuda()
        reset_global_counter()
        with _piecewise_forward_context(vllm_config, num_tokens=2):
            output = model(zeros_input)

        assert get_global_counter() == 2
        assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
125
126
127
128


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor):
    """Piecewise compile with FX-level splitting at ``silly::attention``.

    ``SillyModel`` calls ``silly.attention`` twice, so it is counted as
    ``num_layers = 2`` in the expected values below. ``_run_simple_model``
    configures two CUDA graph capture sizes (1 and 2).
    """
    _run_simple_model(
        splitting_ops=["silly::attention"],
        use_inductor_graph_partition=False,
        use_inductor=use_inductor,
        # 2 * num_layers + 1
        expected_num_piecewise_graphs_seen=5,
        # 1 + num_layers
        expected_num_piecewise_capturable_graphs_seen=3,
        # num_piecewise_capturable_graphs_seen
        expected_num_backend_compilations=3,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
        expected_num_cudagraph_captured=6,
    )


@torch.inference_mode()
def test_simple_inductor_graph_partition(monkeypatch):
    """Piecewise CUDA graphs via Inductor graph partition (PyTorch 2.9+).

    The FX graph is not split here (``use_inductor_graph_partition=True``), so
    the piecewise/backend counters stay at 1. The number of captured CUDA
    graphs is still expected to match the FX-splitting case.
    """
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    # disable compile cache so that we run separately for different splitting_ops
    # and get the expected number of cudagraphs captured.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    _run_simple_model(
        splitting_ops=["silly::attention"],
        use_inductor_graph_partition=True,
        use_inductor=True,
        # Since not splitting at fx graph level
        expected_num_piecewise_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_piecewise_capturable_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_backend_compilations=1,
        # Inductor graph partition still captures 6 graphs, same as fx graph
        # partition
        expected_num_cudagraph_captured=6,
    )