test_lazy_torch_compile.py 1.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
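# Description: Test that importing vllm does not eagerly import
# `torch._inductor.async_compile`.
# The `blame` utility cannot be placed in `vllm.utils`: importing it from there
# would import vllm before tracing starts, so this needs to be a standalone
# script.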

import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator


@dataclasses.dataclass
class BlameResult:
    """Outcome of a :func:`blame` trace."""

    # True once the watched condition has been observed to hold.
    found: bool = False
    # Formatted call stack captured at the first call/return event where the
    # condition held; empty while ``found`` is False.
    trace_stack: str = ""


@contextlib.contextmanager
def blame(func: Callable) -> Generator[BlameResult, None, None]:
    """
    Trace the function calls to find the first function that satisfies the
    condition. The trace stack will be stored in the result.

    ``func`` is a zero-argument predicate evaluated on every Python-level
    ``call`` and ``return`` event while the ``with`` body runs. The first time
    it returns a truthy value, ``result.found`` is set to True and
    ``result.trace_stack`` records the call stack of the frame that triggered
    the event. Tracing then stops for the rest of the body. Only the calling
    thread is traced, because ``sys.settrace`` is per-thread. The trace
    function that was installed before entering is restored on exit, even if
    the body raises.

    Usage:

    ```python
    with blame(lambda: some_condition()) as result:
        # do something

    if result.found:
        print(result.trace_stack)
    ```
    """
    result = BlameResult()
    previous_trace = sys.gettrace()

    def _trace_calls(frame, event, arg=None):
        if event in ("call", "return"):
            # for every function call or return
            try:
                # Suspend tracing while the condition is evaluated so that
                # calls made by ``func`` (or by stack formatting) are not
                # themselves traced.
                sys.settrace(None)
                if func():
                    result.found = True
                    # Start the report at the traced frame, not at this
                    # tracer function.
                    result.trace_stack = "".join(
                        traceback.format_stack(frame))
                    # Only the first hit is reported; leave tracing off so
                    # the rest of the body runs untraced.
                    return None
                # Re-enable the trace function
                sys.settrace(_trace_calls)
            except NameError:
                # modules are deleted during shutdown
                pass
        return _trace_calls

    sys.settrace(_trace_calls)
    try:
        yield result
    finally:
        # Restore whatever tracer was active before (e.g. a debugger or a
        # coverage tool) instead of unconditionally clearing it, and do so
        # even when the body raises.
        sys.settrace(previous_trace)


module_name = "torch._inductor.async_compile"

# A standalone run starts from a clean interpreter. If the module is already
# loaded here, the trace below would blame whatever call happens to come
# first, so fail with a clear message instead.
if module_name in sys.modules:
    raise AssertionError(f"Module {module_name} was imported before vllm;"
                         " run this file as a standalone script")

# Record the call stack at the first call/return event, during the vllm
# import, after which the module appears in sys.modules.
with blame(lambda: module_name in sys.modules) as result:
    import vllm  # noqa

# Explicit raise (not `assert`) so the check still runs under `python -O`.
if result.found:
    raise AssertionError(f"Module {module_name} is already imported, the"
                         f" first import location is:\n{result.trace_stack}")

print(f"Module {module_name} is not imported yet")