test_simple.py 3.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
7

8
import pytest
9
10
11
12
13
import torch
from torch import nn

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
14
15
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                         VllmConfig, set_current_vllm_config)
16
from vllm.envs import VLLM_USE_V1
17
from vllm.forward_context import BatchDescriptor, set_forward_context
18

19
20
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
21
22


23
24
25
@support_torch_compile
class SillyModel(nn.Module):
    """Tiny parameter-free model used to exercise piecewise compilation.

    ``forward`` is a fixed chain of elementwise additions/subtractions
    interleaved with two calls to ``torch.ops.silly.attention``. That op is
    registered as a side effect of importing ``..silly_attention`` and is
    the op the test in this file lists in ``splitting_ops``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # None of the arguments are used by this body; they are accepted so
        # the model can be constructed as
        # ``SillyModel(vllm_config=..., prefix=...)``.
        # NOTE(review): presumably ``support_torch_compile`` consumes
        # ``vllm_config`` -- confirm against the decorator.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x = 3 * x + 19
        global_counter += 2

        NOTE(review): the formula above is the original author's statement.
        The arithmetic performed by ``torch.ops.silly.attention`` lives in
        ``..silly_attention`` and is not visible here, so the coefficient
        on ``x`` should be confirmed there. The test in this file only pins
        the result for an all-zero input (19.0) and the two counter
        increments (one per attention call).
        """
        # Elementwise segment before the first attention call.
        x = x + 1
        x = x + 2
        # The attention op writes its result into a preallocated buffer.
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        # Elementwise segment between the two attention calls.
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        # Elementwise segment after the second attention call.
        x = x + 1
        return x


53
@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
    """Compile ``SillyModel`` piecewise and check counters and output.

    The model is split at ``silly.attention`` and run on CUDA tensors:
    once as a warm-up, once per configured cudagraph capture size (2, then
    1), and finally on a zero input of size 2. The test asserts the
    compilation counters, the number of attention calls made by the final
    run, and the final output values.

    Args:
        use_inductor: forwarded to ``CompilationConfig(use_inductor=...)``.
    """
    assert VLLM_USE_V1

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        use_inductor=use_inductor,
        splitting_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
        cudagraph_capture_sizes=[1, 2],
    ))
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    inputs = torch.randn(100).cuda()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
            num_cudagraph_captured=6,
    ), set_forward_context(None,
                           vllm_config=vllm_config):  # background context
        # warm up with background context
        model(inputs)

        # capturing/replaying should happen under the context of cudagraph
        # dispatching
        with set_forward_context(
                None,
                vllm_config=vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=BatchDescriptor(num_tokens=2, )):
            model(torch.randn(2).cuda())
        with set_forward_context(
                None,
                vllm_config=vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=BatchDescriptor(num_tokens=1, )):
            model(torch.randn(1).cuda())

        # Named ``zero_input`` rather than ``input`` so the builtin is not
        # shadowed.
        zero_input = torch.zeros(2).cuda()
        # Start counting attention calls from zero for the final run only.
        reset_global_counter()
        with set_forward_context(
                None,
                vllm_config=vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=BatchDescriptor(num_tokens=2, )):
            output = model(zero_input)
        # SillyModel.forward invokes the attention op twice per call.
        assert get_global_counter() == 2
        assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))