test_simple.py 3.37 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
7
import pytest
8
9
import torch
from torch import nn
10
from torch.library import Library
11
12
13

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
14
15
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
16
from vllm.envs import VLLM_USE_V1
17
from vllm.forward_context import set_forward_context
18
from vllm.utils import direct_register_custom_op
19
20
21

global_counter = 0

22
23
24
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

25
26
27
28
29
30
31
32
33
34

def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Eager implementation of the ``silly.attention`` custom op.

    ``k`` and ``v`` only make the signature look like attention; they are
    ignored. The result goes into ``out`` in place: a copy of ``q`` with 1
    added to ``out[0]``. Each call also increments (and prints) the
    module-level ``global_counter``, so tests can count how many times the
    op actually executed.
    """
    global global_counter
    global_counter = global_counter + 1
    print(f"{global_counter=}")
    out.copy_(q)
    # In-place add on the view of the first element updates ``out`` itself.
    out[0].add_(1)


35
36
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake implementation registered as ``fake_impl`` for ``silly.attention``.

    The real op writes its result into the caller-provided ``out`` buffer and
    returns ``None``, so there is nothing to allocate or compute here. This is
    deliberately a no-op, and it must not touch ``out``.
    """
    return


40
41
42
43
44
45
46
47
48
# Register ``silly_attention`` in the "silly" library so the model can call it
# as ``torch.ops.silly.attention``. ``out`` is listed in ``mutates_args``
# because the op writes its result into that tensor in place.
direct_register_custom_op(
    target_lib=silly_lib,
    op_name="attention",
    op_func=silly_attention,
    fake_impl=silly_attention_fake,
    mutates_args=["out"],
)


49
50
51
@support_torch_compile
class SillyModel(nn.Module):
    """Tiny model whose output and side effects can be computed by hand.

    ``forward`` interleaves simple arithmetic with two calls to the custom
    ``torch.ops.silly.attention`` op, which ``test_simple_piecewise_compile``
    lists in ``splitting_ops``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # vllm_config, prefix and kwargs are unused: the model has no
        # parameters. The keyword-only signature matches how the test
        # constructs it (``SillyModel(vllm_config=..., prefix='')``).
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        global_counter += 2

        (The two updates to x apply in sequence, so x[0] grows by 3 in total.)
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


80
81
@pytest.mark.parametrize("use_inductor", [True, False])
def test_simple_piecewise_compile(use_inductor):
    """Piecewise-compile ``SillyModel`` with CUDA graphs enabled.

    Checks the compilation counters (graphs seen, pieces, backend
    compilations, captured CUDA graphs). Also checks that the compiled model
    still produces the hand-computed output and runs the custom op twice.
    """
    global global_counter

    assert VLLM_USE_V1

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        use_inductor=use_inductor,
        splitting_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
        cudagraph_capture_sizes=[1, 2],
    ))

    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    inputs = torch.randn(100).cuda()

    # In the comments below, num_layers is the number of silly.attention
    # calls in SillyModel.forward (2), and num_cudagraph_sizes is
    # len(cudagraph_capture_sizes) (2).
    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=6,
    ), set_forward_context({}, vllm_config=vllm_config):

        # A size (100) that is not in cudagraph_capture_sizes.
        model(inputs)

        # Sizes 2 and 1 match cudagraph_capture_sizes.
        model(torch.randn(2).cuda())
        model(torch.randn(1).cuda())

        zeros_input = torch.zeros(2).cuda()

        global_counter = 0
        output = model(zeros_input)

        # forward runs silly.attention twice and maps [0, 0] to [3, 1].
        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))