# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test piecewise compilation with a simple model whose expected output and
side effects can be calculated exactly.
"""

import pytest
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter


@support_torch_compile
class SillyModel(nn.Module):
    """Toy model: elementwise adds/subtracts around two ``silly.attention``.

    The arithmetic is simple enough that the output and the number of
    custom-op invocations can be worked out by hand, which is what lets the
    piecewise-compilation tests assert exact results.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # The arguments are not used by this model; the keyword-only
        # signature presumably mirrors what ``support_torch_compile``
        # expects of decorated models — confirm against the decorator.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fixed op sequence below to ``x``.

        Calls ``torch.ops.silly.attention`` exactly twice, so the global
        counter kept by that op advances by 2 per forward pass.

        For an all-zeros input the expected result is 19 in every element.
        NOTE(review): an earlier version of this docstring stated
        ``x = 3 * x + 19``. If ``silly.attention`` writes ``q + k + v``
        into ``out``, the general net effect is ``9 * x + 19``. Both agree
        only at ``x == 0``; confirm against ``silly_attention``.
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        # First split point: result is written into `out` in place.
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        # Second split point.
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x


def _run_simple_model(
    splitting_ops,
    use_inductor_graph_partition,
    use_inductor,
    expected_num_piecewise_graphs_seen,
    expected_num_piecewise_capturable_graphs_seen,
    expected_num_backend_compilations,
    expected_num_cudagraph_captured,
):
    """Compile and run ``SillyModel`` under piecewise compilation and check it.

    The model is built with ``CompilationLevel.PIECEWISE``, cudagraphs
    enabled, and capture sizes ``[1, 2]``. The run then proceeds in three
    steps:

    1. Warm up once under a background forward context.
    2. Run once per capture size (2, then 1) under ``CUDAGraphMode.PIECEWISE``
       dispatching.
    3. Run an all-zeros input of size 2 the same way.

    The final run must invoke ``silly.attention`` twice and produce 19
    everywhere. Requires a CUDA device, since inputs are moved with
    ``.cuda()``.

    Args:
        splitting_ops: forwarded to ``CompilationConfig.splitting_ops``.
        use_inductor_graph_partition: forwarded to ``CompilationConfig``.
        use_inductor: forwarded to ``CompilationConfig``.
        expected_num_piecewise_graphs_seen: value for
            ``num_piecewise_graphs_seen`` in ``compilation_counter.expect``.
        expected_num_piecewise_capturable_graphs_seen: value for
            ``num_piecewise_capturable_graphs_seen``.
        expected_num_backend_compilations: value for
            ``num_backend_compilations``.
        expected_num_cudagraph_captured: value for
            ``num_cudagraph_captured``.
    """
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        use_inductor=use_inductor,
        splitting_ops=splitting_ops,
        use_inductor_graph_partition=use_inductor_graph_partition,
        cudagraph_copy_inputs=True,
        cudagraph_capture_sizes=[1, 2],
    ))

    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    def piecewise_cudagraph_context(num_tokens: int):
        # Capturing/replaying must happen under cudagraph dispatching for a
        # specific batch size.
        return set_forward_context(
            None,
            vllm_config=vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
        )

    warmup_inputs = torch.randn(100).cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
            num_piecewise_capturable_graphs_seen=
            expected_num_piecewise_capturable_graphs_seen,
            num_backend_compilations=expected_num_backend_compilations,
            num_cudagraph_captured=expected_num_cudagraph_captured,
    ), set_forward_context(None,
                           vllm_config=vllm_config):  # background context
        # Warm up with the background context (no cudagraph dispatching).
        model(warmup_inputs)

        # One run per configured capture size, in the original order (2, 1).
        for num_tokens in (2, 1):
            with piecewise_cudagraph_context(num_tokens):
                model(torch.randn(num_tokens).cuda())

        zeros_input = torch.zeros(2).cuda()
        reset_global_counter()
        with piecewise_cudagraph_context(2):
            output = model(zeros_input)

        # SillyModel.forward invokes silly.attention twice.
        assert get_global_counter() == 2
        assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
    """FX-level piecewise splitting at ``silly.attention``.

    The expected counters are derived from the model's shape:
    ``SillyModel.forward`` calls ``silly.attention`` twice ("layers"), and
    cudagraphs are captured for two sizes (``[1, 2]``).
    """
    assert VLLM_USE_V1

    num_layers = 2
    num_cudagraph_sizes = 2
    # Splitting around each attention call yields one more capturable piece
    # than there are layers; each of those pieces is compiled once and
    # captured once per cudagraph size.
    capturable_graphs = num_layers + 1

    _run_simple_model(
        splitting_ops=["silly.attention"],
        use_inductor_graph_partition=False,
        use_inductor=use_inductor,
        expected_num_piecewise_graphs_seen=2 * num_layers + 1,
        expected_num_piecewise_capturable_graphs_seen=capturable_graphs,
        expected_num_backend_compilations=capturable_graphs,
        expected_num_cudagraph_captured=(num_cudagraph_sizes *
                                         capturable_graphs),
    )


@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
@torch.inference_mode()
def test_simple_inductor_graph_partition(splitting_ops):
    """Inductor graph partition instead of FX-level splitting.

    Inductor graph partition automatically resets ``splitting_ops`` to an
    empty list, so the FX graph is never split. Both parametrizations are
    therefore expected to behave the same.
    """
    assert VLLM_USE_V1
    if not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available "
                    "in PyTorch 2.9+")

    # No FX-level split: exactly one piecewise graph, which is also the only
    # capturable graph and the only backend compilation.
    unsplit_graph_count = 1

    _run_simple_model(
        splitting_ops=splitting_ops,
        use_inductor_graph_partition=True,
        use_inductor=True,
        expected_num_piecewise_graphs_seen=unsplit_graph_count,
        expected_num_piecewise_capturable_graphs_seen=unsplit_graph_count,
        expected_num_backend_compilations=unsplit_graph_count,
        # Inductor graph partition is still expected to capture 6 cudagraphs,
        # the same number as FX-level splitting.
        expected_num_cudagraph_captured=6,
    )