# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""

import pytest
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    CUDAGraphMode,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter


@support_torch_compile
class SillyModel(nn.Module):
    """Tiny parameter-free model whose output and side effects are predictable.

    ``forward`` applies a fixed sequence of elementwise additions and
    subtractions around two ``torch.ops.silly.attention`` calls. Those calls
    are the ops the piecewise tests in this file split on
    (``splitting_ops=["silly.attention"]``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        # The call site passes ``vllm_config`` and ``prefix``. Presumably the
        # ``support_torch_compile`` wrapper requires them; the model itself
        # stores nothing and has no submodules.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fixed add / attention / subtract pipeline on ``x``.

        The net elementwise adjustments are:

        - ``+3`` before the first ``silly.attention`` call,
        - ``-3`` between the two calls,
        - ``+1`` after the second call.

        Documented overall effect::

            x = 3 * x + 19
            global_counter += 2

        NOTE(review): the closed form depends on what ``silly.attention``
        writes into ``out`` (defined in ``..silly_attention``); confirm it
        there. ``_run_simple_model`` only checks two things: an all-zeros
        input yields 19 everywhere, and one forward pass adds 2 to the global
        counter.
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


def _piecewise_forward_context(vllm_config: VllmConfig, num_tokens: int):
    """Return a forward context that uses piecewise CUDA graph dispatch.

    Args:
        vllm_config: The config the model was built with.
        num_tokens: Batch size recorded in the ``BatchDescriptor``. It selects
            which captured size is dispatched.

    Returns:
        The (not yet entered) context manager from ``set_forward_context``.
    """
    return set_forward_context(
        None,
        vllm_config=vllm_config,
        cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
        batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
    )


def _run_simple_model(
    splitting_ops,
    use_inductor_graph_partition,
    use_inductor,
    expected_num_piecewise_graphs_seen,
    expected_num_piecewise_capturable_graphs_seen,
    expected_num_backend_compilations,
    expected_num_cudagraph_captured,
):
    """Compile ``SillyModel`` piecewise with CUDA graphs and check the results.

    The model is built under a PIECEWISE ``CompilationConfig`` with CUDA graph
    capture sizes 1 and 2. The function then:

    1. Warms the model up under a plain (background) forward context.
    2. Runs it once for each capture size under piecewise CUDA graph dispatch.
    3. Runs a size-2 batch of zeros again and checks the exact output and the
       global attention counter.

    Args:
        splitting_ops: FX-level split points, e.g. ``["silly.attention"]``,
            or ``[]``.
        use_inductor_graph_partition: Forwarded to ``CompilationConfig``.
        use_inductor: Forwarded to ``CompilationConfig``.
        expected_num_piecewise_graphs_seen: Expected
            ``compilation_counter.num_piecewise_graphs_seen`` delta.
        expected_num_piecewise_capturable_graphs_seen: Expected
            ``compilation_counter.num_piecewise_capturable_graphs_seen`` delta.
        expected_num_backend_compilations: Expected
            ``compilation_counter.num_backend_compilations`` delta.
        expected_num_cudagraph_captured: Expected
            ``compilation_counter.num_cudagraph_captured`` delta.

    Raises:
        AssertionError: If the global counter or the output is wrong. Counter
            mismatches are checked by ``compilation_counter.expect`` when its
            context exits; presumably it raises ``AssertionError`` too.
    """
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            use_inductor=use_inductor,
            splitting_ops=splitting_ops,
            use_inductor_graph_partition=use_inductor_graph_partition,
            cudagraph_copy_inputs=True,
            cudagraph_capture_sizes=[1, 2],
        )
    )
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix="")

    inputs = torch.randn(100).cuda()

    with (
        compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
            num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
            num_backend_compilations=expected_num_backend_compilations,
            num_cudagraph_captured=expected_num_cudagraph_captured,
        ),
        set_forward_context(None, vllm_config=vllm_config),
    ):  # background context
        # Warm up under the background context (no CUDA graph dispatch).
        model(inputs)

        # Capturing/replaying must happen under a CUDA-graph-dispatching
        # forward context: one run per capture size.
        with _piecewise_forward_context(vllm_config, num_tokens=2):
            model(torch.randn(2).cuda())
        with _piecewise_forward_context(vllm_config, num_tokens=1):
            model(torch.randn(1).cuda())

        # Run size 2 again on zeros so the output is exactly predictable.
        zero_input = torch.zeros(2).cuda()
        reset_global_counter()
        with _piecewise_forward_context(vllm_config, num_tokens=2):
            output = model(zero_input)

        # SillyModel.forward makes two silly.attention calls.
        assert get_global_counter() == 2
        assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
    """FX-graph-level piecewise compilation split at ``silly.attention``.

    Here ``num_layers`` is 2: the number of ``silly.attention`` calls in
    ``SillyModel.forward``.
    """
    _run_simple_model(
        splitting_ops=["silly.attention"],
        use_inductor_graph_partition=False,
        use_inductor=use_inductor,
        # 2 * num_layers + 1
        expected_num_piecewise_graphs_seen=5,
        # 1 + num_layers
        expected_num_piecewise_capturable_graphs_seen=3,
        # num_piecewise_capturable_graphs_seen
        expected_num_backend_compilations=3,
        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
        expected_num_cudagraph_captured=6,
    )


@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
@torch.inference_mode()
def test_simple_inductor_graph_partition(splitting_ops):
    """Piecewise CUDA graphs via Inductor graph partition, not FX splitting.

    This test is skipped on PyTorch versions older than ``2.9.0.dev``.
    """
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    _run_simple_model(
        # Inductor graph partition automatically resets splitting_ops to an empty list
        splitting_ops=splitting_ops,
        use_inductor_graph_partition=True,
        use_inductor=True,
        # Since not splitting at fx graph level
        expected_num_piecewise_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_piecewise_capturable_graphs_seen=1,
        # Since not splitting at fx graph level
        expected_num_backend_compilations=1,
        # Inductor graph partition still captures 6 graphs, same as fx graph partition
        expected_num_cudagraph_captured=6,
    )