# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
from torch import fx as fx
from torch import nn

# This import automatically registers `torch.ops.silly.attention`
import tests.compile.silly_attention  # noqa
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.passes.inductor_pass import (
    InductorPass,
    get_pass_context,
)
from vllm.config import (
    VllmConfig,
    set_current_vllm_config,
)
from vllm.config.compilation import CompilationConfig, CompilationMode
from vllm.config.scheduler import SchedulerConfig
from vllm.config.utils import Range
from vllm.forward_context import set_forward_context

# Row count of the first input that run_model() feeds to the model, before it
# iterates over the caller-supplied batch sizes.
BATCH_SIZE = 64
# Second dimension of every random input tensor created in run_model().
MLP_SIZE = 128


@support_torch_compile
class TestModel(nn.Module):
    """Minimal parameter-free module used as the compilation target.

    The forward pass doubles the input, runs the `silly.attention` custom op
    (registered by importing `tests.compile.silly_attention`) and scales the
    op's output buffer by three.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
        # The constructor arguments are accepted but not used by this module.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `3 * attention_out`, where the op is fed `x + x` three times."""
        doubled = x + x
        # The custom op writes its result into a preallocated output tensor.
        attn_output = torch.empty_like(doubled)
        torch.ops.silly.attention(doubled, doubled, doubled, attn_output)
        return attn_output * 3


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]):
    """Invoke `model` on random inputs inside a forward context.

    The model is first called with `BATCH_SIZE` rows and then once per entry
    of `batch_sizes`, in order. Every input has `MLP_SIZE` columns. The
    model outputs are discarded.
    """
    with set_forward_context({}, vllm_config=vllm_config):
        # Leading BATCH_SIZE call first, then the requested sizes.
        for num_rows in (BATCH_SIZE, *batch_sizes):
            model(torch.randn(num_rows, MLP_SIZE))


class PostGradRangeChecker(InductorPass):
    """Inductor pass that validates the active compile range and counts calls.

    The pass leaves the graph untouched: it only asserts that the compile
    range taken from the current pass context is one of the expected ranges
    and increments `num_calls`.
    """

    def __init__(self, ranges: list[Range]):
        # Number of times the pass has been invoked so far.
        self.num_calls = 0
        # Compile ranges the pass is allowed to be invoked for.
        self.ranges = ranges

    def __call__(self, graph: fx.Graph):
        """Assert the current compile range is expected, then count the call."""
        compile_range = get_pass_context().compile_range
        assert compile_range in self.ranges, (
            f"Compile range {compile_range} not in {self.ranges}"
        )
        self.num_calls += 1

    def uuid(self) -> str:
        """Return the hash of an empty state dict (no state is hashed)."""
        return InductorPass.hash_dict({})


def test_compile_ranges(use_fresh_inductor_cache):
    """Check the number of compilations with compile ranges and compile sizes.

    The config uses range endpoints `[8, 32]` and compile sizes
    `[16, 64, 128]` with a maximum of 8192 batched tokens. The test expects
    six backend compilations, and expects the post-grad checker pass to be
    called six times, each time with one of the ranges listed below.

    `use_fresh_inductor_cache` is a fixture defined outside this file.
    """
    # Every compile range the post-grad pass is allowed to observe.
    post_grad_range_checker = PostGradRangeChecker(
        [
            Range(start=1, end=8),
            Range(start=16, end=16),
            Range(start=9, end=32),
            Range(start=64, end=64),
            Range(start=128, end=128),
            Range(start=33, end=8192),
        ]
    )
    torch.set_default_device("cuda")
    vllm_config = VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
            max_model_len=8192,
            is_encoder_decoder=False,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            compile_ranges_endpoints=[8, 32],
            compile_sizes=[16, 64, 128],
            inductor_compile_config={
                "post_grad_custom_post_pass": post_grad_range_checker,
            },
        ),
    )

    with set_current_vllm_config(vllm_config):
        model = TestModel(vllm_config=vllm_config, prefix="").eval()
        # Number of compilations: 3 compile ranges + 3 compile sizes
        batch_sizes = [1, 4, 16, 24, 48, 64, 8192]

        with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=1,
            num_backend_compilations=6,
        ):
            run_model(vllm_config, model, batch_sizes)
        assert post_grad_range_checker.num_calls == 6


def test_compile_config_get_compile_ranges():
    """Check that endpoints `[8, 32]` produce three contiguous compile ranges.

    With `max_num_batched_tokens=8192`, `get_compile_ranges()` is expected
    to return `[1, 8]`, `[9, 32]` and `[33, 8192]`.
    """
    compilation_config = CompilationConfig(
        compile_ranges_endpoints=[8, 32],
    )
    # The VllmConfig instance itself is not used afterwards; it is built only
    # for its effect on `compilation_config`.
    # NOTE(review): presumably VllmConfig initialization finalizes the
    # compilation config using the scheduler limits - confirm.
    VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_batched_tokens=8192,
            max_model_len=8192,
            is_encoder_decoder=False,
        ),
        compilation_config=compilation_config,
    )
    assert compilation_config.get_compile_ranges() == [
        Range(start=1, end=8),
        Range(start=9, end=32),
        Range(start=33, end=8192),
    ]


def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache):
    """Check that a second model with the same config does not re-run the pass.

    Two models are built from identically constructed configs that share one
    `PostGradRangeChecker`. The first run is expected to invoke the pass
    twice (once per compile range). After the counter is reset, the second
    run is expected to invoke it zero times.

    `use_fresh_inductor_cache` is a fixture defined outside this file.
    """
    # To force multiple compilations, we disable the compile cache
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    post_grad_range_checker = PostGradRangeChecker(
        ranges=[
            Range(start=1, end=8),
            Range(start=9, end=8192),
        ]
    )
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_model_len=8192,
        is_encoder_decoder=False,
    )
    torch.set_default_device("cuda")

    def create_vllm_config():
        # Each call builds a new VllmConfig; the scheduler config and the
        # checker pass instance are shared between calls.
        return VllmConfig(
            scheduler_config=scheduler_config,
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                compile_ranges_endpoints=[8],
                inductor_compile_config={
                    "post_grad_custom_post_pass": post_grad_range_checker,
                },
            ),
        )

    vllm_config_1 = create_vllm_config()
    with set_current_vllm_config(vllm_config_1):
        model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval()
        batch_sizes = [1, 16]
        run_model(vllm_config_1, model1, batch_sizes)
        assert post_grad_range_checker.num_calls == 2

    post_grad_range_checker.num_calls = 0
    # Create a new vllm config with the new pass context
    vllm_config_2 = create_vllm_config()
    with set_current_vllm_config(vllm_config_2):
        model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval()
        batch_sizes = [4, 32]
        run_model(vllm_config_2, model2, batch_sizes)
        # Check that cache is used, so the number of calls
        # should be 0
        assert post_grad_range_checker.num_calls == 0