# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel


class MyMod(torch.nn.Module):
    """Toy module whose result depends on whether a cache tensor is given."""

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """Return ``x * 2`` when ``cache`` is None, otherwise ``x + cache``."""
        if cache is None:
            return x * 2
        return x + cache


class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
    """Test wrapper that compiles ``forward`` and then dispatches by hand.

    ``compiled_callable``, ``compiled_codes`` and ``dispatch_to_code`` are not
    defined in this class; they are presumably provided by the base class
    (confirm against ``vllm.compilation.wrapper``).
    """

    def __init__(self, model):
        """Store ``model`` and hand a compiled ``forward`` to the base class."""
        self.model = model
        # Compile the bound ``forward`` method using the "eager" backend.
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(
            compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE
        )

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """Delegate to the wrapped model."""
        # this is the function to be compiled
        return self.model(x, cache)

    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """Run the compiled callable until two codes exist, then dispatch.

        Once ``compiled_codes`` holds exactly two entries, index 0 is used
        when ``cache`` is None and index 1 otherwise.
        """
        # let torch.compile compile twice
        if len(self.compiled_codes) != 2:
            return self.compiled_callable(x, cache)

        dispatch_id = 0 if cache is None else 1
        with self.dispatch_to_code(dispatch_id):
            return self.forward(x, cache)


def test_torch_compile_wrapper():
    """Check that each wrapper compiles twice and then dispatches correctly.

    Three wrappers are built around the same module, resetting dynamo before
    each one. Every wrapper is called once without and once with a cache
    tensor, and later calls must return the result of the matching branch.
    """
    mod = MyMod()
    wrappers = []
    for _ in range(3):
        torch._dynamo.reset()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # create a cache tensor
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile

        # for new input, dispatch to the compiled code directly
        new_x = torch.tensor([3])
        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code

    for wrapper in wrappers:
        # make sure they have independent compiled codes
        assert len(wrapper.compiled_codes) == 2