profiler.py 4.1 KB
Newer Older
PengGao's avatar
PengGao committed
1
import asyncio
PengGao's avatar
PengGao committed
2
import time
PengGao's avatar
PengGao committed
3
from functools import wraps
PengGao's avatar
PengGao committed
4
5

import torch
helloyongyang's avatar
helloyongyang committed
6
import torch.distributed as dist
root's avatar
root committed
7
from loguru import logger
8

PengGao's avatar
PengGao committed
9
from lightx2v.utils.envs import *
10
11
12
from lightx2v_platform.base.global_var import AI_DEVICE

# Device-specific torch submodule named by AI_DEVICE (presumably e.g. "cuda";
# confirm against lightx2v_platform). The profiling contexts in this module call
# its synchronize() around timed regions.
torch_device_module = getattr(torch, AI_DEVICE)


class _ProfilingContext:
    """Time a code region and optionally record and/or log the duration.

    Works as a sync context manager (``with``), an async context manager
    (``async with``), or a decorator for plain and coroutine functions.
    ``synchronize()`` is called on the device module before the clock is read on
    both entry and exit.
    """

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        """
        Args:
            name: label included in the log line.
            recorder_mode: 0 = disable recorder;
                1 = enable recorder;
                2 = enable recorder and force disable logger.
            metrics_func: object whose ``observe(seconds)`` is called when the
                recorder is enabled; when ``metrics_labels`` is truthy,
                ``metrics_func.labels(*metrics_labels).observe(seconds)`` is used
                instead.
            metrics_labels: optional label values passed to ``metrics_func.labels``.
        """
        self.name = name
        # is_available() guards torch builds compiled without distributed support.
        if dist.is_available() and dist.is_initialized():
            self.rank_info = f"Rank {dist.get_rank()}"
        else:
            self.rank_info = "Single GPU"
        self.enable_recorder = recorder_mode > 0
        self.enable_logger = recorder_mode <= 1
        self.metrics_func = metrics_func
        self.metrics_labels = metrics_labels

    def _begin(self):
        """Synchronize the device and return the start timestamp."""
        torch_device_module.synchronize()
        return time.perf_counter()

    def _end(self, start_time):
        """Synchronize the device, then record/log the seconds elapsed since ``start_time``.

        Returns:
            The elapsed time in seconds.
        """
        torch_device_module.synchronize()
        elapsed = time.perf_counter() - start_time
        if self.enable_recorder and self.metrics_func:
            if self.metrics_labels:
                self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
            else:
                self.metrics_func.observe(elapsed)
        if self.enable_logger:
            logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
        return elapsed

    def __enter__(self):
        self.start_time = self._begin()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._end(self.start_time)
        # Returning False lets exceptions from the timed body propagate.
        return False

    async def __aenter__(self):
        self.start_time = self._begin()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._end(self.start_time)
        return False

    def __call__(self, func):
        """Wrap ``func`` so each call is timed; coroutine functions stay async."""
        # Each call keeps its own start time in a local instead of going through
        # `with self:`. Otherwise concurrent coroutines (or recursive calls)
        # sharing this decorator instance would overwrite self.start_time and
        # report wrong durations.
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                start_time = self._begin()
                try:
                    return await func(*args, **kwargs)
                finally:
                    self._end(start_time)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = self._begin()
            try:
                return func(*args, **kwargs)
            finally:
                self._end(start_time)

        return sync_wrapper


class _NullContext:
86
87
88
89
90
91
92
93
94
95
    # Context manager without decision branch logic overhead
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False

PengGao's avatar
PengGao committed
96
97
98
99
100
101
102
103
104
    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return False

    def __call__(self, func):
        return func

class _ProfilingContextL1(_ProfilingContext):
    """Level 1 profiling context with Level1_Log prefix."""

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        # Prefix the name so level-1 entries are distinguishable in the log output.
        super().__init__(f"Level1_Log {name}", recorder_mode, metrics_func, metrics_labels)


class _ProfilingContextL2(_ProfilingContext):
    """Level 2 profiling context with Level2_Log prefix."""

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        # Prefix the name so level-2 entries are distinguishable in the log output.
        super().__init__(f"Level2_Log {name}", recorder_mode, metrics_func, metrics_labels)


"""
PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling
PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1
PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2
"""
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling