# test_simple.py
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""

import torch
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op

global_counter = 0

19
20
21
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

22
23
24
25
26
27
28
29
30
31

def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Stand-in "attention" op whose result is trivially predictable.

    ``k`` and ``v`` are accepted only to mimic an attention signature; they
    are never read. The op first bumps the module-level ``global_counter``
    and prints it, so a test can count real invocations. It then fills the
    caller-provided ``out`` with ``q`` and adds 1 to ``out[0]``.

    Returns:
        None. The result is written into ``out`` in place.
    """
    global global_counter
    global_counter = global_counter + 1
    print(f"{global_counter=}")
    out.copy_(q)
    # In-place add on the view ``out[0]`` updates ``out`` itself.
    out[0].add_(1)


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation of ``silly.attention``, registered as its ``fake_impl``.

    The real op only writes into the caller-provided ``out`` buffer and
    returns nothing. There is therefore no output tensor to fabricate, and
    this fake implementation intentionally does nothing.
    """
    return


# Register ``silly_attention`` in ``silly_lib`` as the op ``attention``, so the
# model can call it as ``torch.ops.silly.attention``. ``out`` is listed as
# mutated because the op writes its result into that buffer in place.
# ``silly_attention_fake`` is supplied as the fake implementation.
direct_register_custom_op(
    op_name="attention",
    target_lib=silly_lib,
    op_func=silly_attention,
    fake_impl=silly_attention_fake,
    mutates_args=["out"],
)


@support_torch_compile
class SillyModel(nn.Module):
    """Toy model for exercising piecewise compilation.

    ``forward`` is plain elementwise arithmetic wrapped around two calls to
    the custom ``torch.ops.silly.attention`` op. It was chosen so that both
    the exact output and the number of op invocations can be computed by
    hand.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        """Build the (parameter-free) model.

        Args:
            vllm_config: Configuration object; not read by this class body.
                NOTE(review): presumably required by the
                ``support_torch_compile`` constructor convention — confirm.
            prefix: Module name prefix; unused here.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run arithmetic around two ``silly.attention`` calls.

        The input ``x`` is not modified; every step rebinds ``x`` to a new
        tensor. Compared with the input, the returned tensor has:

        * every element increased by 1, and
        * ``x[0]`` increased by a further 2 (each op call adds 1 to it).

        The module-level ``global_counter`` grows by 2, one per op call.
        """
        x = x + 1
        x = x + 2
        # The op writes into a caller-allocated buffer and returns nothing,
        # so allocate ``out`` first and continue with it afterwards.
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


def test_simple_piecewise_compile():
    """Piecewise-compile ``SillyModel`` and check counters and output.

    With ``silly.attention`` listed in ``non_cudagraph_ops``, the expected
    counts are:

    * 5 piecewise subgraphs in total;
    * 3 capturable subgraphs, each compiled by Inductor;
    * CUDA graphs for each of those 3 subgraphs at the 2 compile sizes
      ``[1, 2]``, i.e. 6 captures.

    A zero input of length 2 must then produce ``[3., 1.]`` with exactly
    two op invocations.
    """
    # Declared up front because the test resets the counter below.
    global global_counter

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        non_cudagraph_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
    ))
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    inputs = torch.randn(100).cuda()

    # "num_layers" in the comments below is the number of silly.attention
    # calls in SillyModel.forward, i.e. 2.
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_inductor_compilations=3,  # num_piecewise_capturable_graphs_seen
            # NOTE(review): the spelling "caputured" presumably matches the
            # field name in vllm's compilation_counter — do not "fix" it here.
            num_cudagraph_caputured=
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        with set_compile_context([1, 2]):
            # NOTE(review): size 100 is not a capture size; presumably this
            # first call triggers compilation for a general shape, while
            # sizes 2 and 1 below are captured as CUDA graphs — confirm.
            model(inputs)

            model(torch.randn(2).cuda())
            model(torch.randn(1).cuda())

        test_input = torch.zeros(2).cuda()

        # Reset so only the op calls from the next forward pass are counted.
        global_counter = 0
        output = model(test_input)

        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))