# monitor.py 2.17 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import time

6
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
7
8
9
10
from vllm.logger import init_logger

# Module-level logger for this file, created with vLLM's init_logger.
logger = init_logger(__name__)

11
# Debug-dump context manager entered by start_monitoring_torch_compile();
# stays None when no dump is active.
context_manager = None

# time.perf_counter() reading taken when compile monitoring started.
torch_compile_start_time: float = 0.0

14

15
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Start timing torch.compile and optionally open a depyf debug dump.

    The current ``time.perf_counter()`` value is stored in the module-level
    ``torch_compile_start_time``. When the compilation mode is
    ``CompilationMode.VLLM_COMPILE`` and ``compile_debug_dump_path()`` returns
    a truthy path, the directory is created and a ``depyf.prepare_debug``
    context is entered and stored in the module-level ``context_manager``.

    Args:
        vllm_config: Configuration providing ``compilation_config`` and
            ``compile_debug_dump_path()``.
    """
    global torch_compile_start_time, context_manager
    torch_compile_start_time = time.perf_counter()

    compilation_config: CompilationConfig = vllm_config.compilation_config
    path = vllm_config.compile_debug_dump_path()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
        # Imported lazily so depyf is only required when a dump is requested.
        import depyf

        path.mkdir(parents=True, exist_ok=True)
        logger.debug("Dumping depyf output to %s", path)
        manager = depyf.prepare_debug(path.as_posix())
        manager.__enter__()
        # Publish the manager only after it was entered successfully, so the
        # end hook never calls __exit__ on a manager that was never entered.
        context_manager = manager
29

30

31
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Log the elapsed compile time and close any active debug dump context.

    The elapsed time is measured from the module-level
    ``torch_compile_start_time``. Logging and context teardown only happen
    when the compilation mode is ``CompilationMode.VLLM_COMPILE``.

    Args:
        vllm_config: Configuration providing ``compilation_config``.
    """
    global context_manager
    compilation_config: CompilationConfig = vllm_config.compilation_config
    total_compile_time: float = time.perf_counter() - torch_compile_start_time
    if compilation_config.mode != CompilationMode.VLLM_COMPILE:
        return

    logger.info_once(
        "torch.compile and initial profiling run took %.2f s in total",
        total_compile_time,
        scope="local",
    )
    if context_manager is not None:
        try:
            context_manager.__exit__(None, None, None)
        finally:
            # Drop the reference even if __exit__ raises, so a later call
            # does not try to exit the same manager a second time.
            context_manager = None
44
45
46
47
48


# Module-level switch checked by validate_cudagraph_capturing_enabled() and
# changed through set_cudagraph_capturing_enabled(); enabled by default.
cudagraph_capturing_enabled: bool = True


49
def validate_cudagraph_capturing_enabled() -> None:
    """Check that CUDA graph capturing is currently allowed.

    Used to monitor whether a cudagraph capture is legal at runtime; it
    should be called before any cudagraph capturing.

    Raises:
        RuntimeError: If the module-level ``cudagraph_capturing_enabled``
            flag is falsy, i.e. an illegal capture is being attempted.
    """
    # Read-only access to the module flag, so no ``global`` statement needed.
    if not cudagraph_capturing_enabled:
        raise RuntimeError(
            "CUDA graph capturing detected at an inappropriate "
            "time. This operation is currently disabled."
        )
59
60


61
def set_cudagraph_capturing_enabled(enabled: bool) -> None:
    """Set the module-level ``cudagraph_capturing_enabled`` flag.

    Args:
        enabled: New value for the flag; when falsy,
            ``validate_cudagraph_capturing_enabled()`` raises.
    """
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled