monitor.py 1.31 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
5
import time

6
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
7
8
9
10
from vllm.logger import init_logger

logger = init_logger(__name__)

11
# Debug context currently entered by start_monitoring_torch_compile, or None
# when no context is open; end_monitoring_torch_compile exits it and resets
# this back to None.
context_manager = None
12
13
# time.time() value captured by the most recent call to
# start_monitoring_torch_compile; 0.0 until that function has been called.
torch_compile_start_time: float = 0.0

14

15
def start_monitoring_torch_compile(vllm_config: VllmConfig):
    """Begin monitoring a torch.compile run.

    Stores the current ``time.time()`` in the module-level
    ``torch_compile_start_time``. When the compilation level is
    ``CompilationLevel.PIECEWISE`` and ``debug_dump_path`` is set, also
    creates a ``depyf.prepare_debug`` context for the directory
    ``<debug_dump_path>/rank_<rank>``, enters it, and keeps it in the
    module-level ``context_manager`` so that
    ``end_monitoring_torch_compile`` can exit it later.

    Args:
        vllm_config: Configuration providing ``compilation_config`` and
            ``parallel_config.rank``.
    """
    global torch_compile_start_time, context_manager

    torch_compile_start_time = time.time()

    compilation_config: CompilationConfig = vllm_config.compilation_config

    is_piecewise = compilation_config.level == CompilationLevel.PIECEWISE
    if not is_piecewise:
        return

    dump_root = compilation_config.debug_dump_path
    if not dump_root:
        return

    # Imported lazily: depyf is only needed when a debug dump is requested.
    import depyf

    rank_dir = os.path.join(dump_root,
                            f"rank_{vllm_config.parallel_config.rank}")
    context_manager = depyf.prepare_debug(rank_dir)
    # Entered manually (no ``with``) because the matching exit happens in
    # end_monitoring_torch_compile.
    context_manager.__enter__()
28

29
30
31

def end_monitoring_torch_compile(vllm_config: VllmConfig):
    """Finish monitoring a torch.compile run.

    Does nothing unless the compilation level is
    ``CompilationLevel.PIECEWISE``. In that case it logs
    ``compilation_config.compilation_time`` and, if a debug context is
    stored in the module-level ``context_manager``, exits it and resets
    the global to ``None``.

    Args:
        vllm_config: Configuration providing ``compilation_config``.
    """
    global context_manager

    compilation_config: CompilationConfig = vllm_config.compilation_config

    is_piecewise = compilation_config.level == CompilationLevel.PIECEWISE
    if not is_piecewise:
        return

    logger.info("torch.compile takes %.2f s in total",
                compilation_config.compilation_time)

    active_context = context_manager
    if active_context is None:
        return

    # Exit with no exception info, mirroring a clean ``with`` block exit.
    active_context.__exit__(None, None, None)
    context_manager = None