profiler.py 4.2 KB
Newer Older
PengGao's avatar
PengGao committed
1
import asyncio
PengGao's avatar
PengGao committed
2
import time
PengGao's avatar
PengGao committed
3
from functools import wraps
PengGao's avatar
PengGao committed
4
5

import torch
helloyongyang's avatar
helloyongyang committed
6
import torch.distributed as dist
root's avatar
root committed
7
from loguru import logger
8

PengGao's avatar
PengGao committed
9
10
from lightx2v.utils.envs import *

11

PengGao's avatar
PengGao committed
12
class _ProfilingContext:
    """Wall-clock profiler usable as a sync/async context manager or a decorator.

    The measured span is bracketed by two ``device_synchronize()`` calls and
    timed with ``time.perf_counter``. On exit, the elapsed seconds are
    optionally observed on ``metrics_func`` and optionally logged.
    Exceptions raised inside the profiled span are never suppressed.
    """

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        """
        Args:
            name: label included in the log line.
            recorder_mode:
                0: disable recorder
                1: enable recorder
                2: enable recorder and force disable logger
            metrics_func: object with ``observe(seconds)``; when
                ``metrics_labels`` is truthy it must also provide
                ``labels(*metrics_labels)`` returning such an object
                (presumably a Prometheus-style histogram/summary - confirm
                against callers).
            metrics_labels: optional sequence of label values passed to
                ``metrics_func.labels``.
        """
        self.name = name
        # Kept for backward compatibility. The log line re-reads the rank at
        # exit time instead, because decorator instances are usually created
        # at import time, before torch.distributed has been initialized.
        self.rank_info = self._current_rank_info()
        self.enable_recorder = recorder_mode > 0
        self.enable_logger = recorder_mode <= 1
        self.metrics_func = metrics_func
        self.metrics_labels = metrics_labels

    @staticmethod
    def _current_rank_info():
        """Return "Rank <n>" when torch.distributed is initialized, else "Single GPU"."""
        if dist.is_initialized():
            return f"Rank {dist.get_rank()}"
        return "Single GPU"

    def _start(self):
        """Synchronize the device, then return the start timestamp."""
        self.device_synchronize()
        return time.perf_counter()

    def _finish(self, start_time):
        """Synchronize the device, then record and/or log the time since ``start_time``."""
        self.device_synchronize()
        elapsed = time.perf_counter() - start_time
        if self.enable_recorder and self.metrics_func:
            if self.metrics_labels:
                self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
            else:
                self.metrics_func.observe(elapsed)
        if self.enable_logger:
            logger.info(f"[Profile] {self._current_rank_info()} - {self.name} cost {elapsed:.6f} seconds")

    def __enter__(self):
        self.start_time = self._start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._finish(self.start_time)
        # Never swallow an exception raised inside the profiled block.
        return False

    async def __aenter__(self):
        # NOTE: device_synchronize() is a plain synchronous call, so it runs
        # on (and blocks) the event-loop thread while it waits.
        self.start_time = self._start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._finish(self.start_time)
        return False

    def __call__(self, func):
        """Decorate ``func`` (sync or coroutine function) so each call is profiled.

        The start time lives in a local variable of the wrapper rather than
        on ``self``. A single decorator instance is shared by every call of
        the wrapped function, so storing it on ``self`` let overlapping
        calls (concurrent tasks, threads, recursion) overwrite each other's
        start time and report wrong durations.
        """
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                start_time = self._start()
                try:
                    return await func(*args, **kwargs)
                finally:
                    self._finish(start_time)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = self._start()
            try:
                return func(*args, **kwargs)
            finally:
                self._finish(start_time)

        return sync_wrapper

    def device_synchronize(self):
        """Block until queued device work finishes: CUDA if available, else MLU if torch exposes it."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()

PengGao's avatar
PengGao committed
90
91

class _NullContext:
    """No-op stand-in used when a profiling level is disabled.

    It accepts (and ignores) whatever constructor arguments the real
    profilers take. Used as a sync or async context manager, it yields
    itself and never suppresses exceptions. Used as a decorator, it hands
    the function back untouched, so no wrapper is added to each call.
    """

    def __init__(self, *args, **kwargs):
        """Ignore every argument; they are accepted only for signature compatibility."""

    def __call__(self, func):
        return func

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return False

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, *exc_info):
        return self.__exit__(*exc_info)

111

112
113
114
class _ProfilingContextL1(_ProfilingContext):
    """Level-1 profiler.

    It behaves exactly like its base class, except that the profiled name is
    prefixed with ``Level1_Log`` so level-1 log lines are easy to tell apart.
    """

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        prefixed_name = f"Level1_Log {name}"
        super().__init__(
            prefixed_name,
            recorder_mode=recorder_mode,
            metrics_func=metrics_func,
            metrics_labels=metrics_labels,
        )
117
118
119
120
121


class _ProfilingContextL2(_ProfilingContext):
    """Level-2 profiler.

    It behaves exactly like its base class, except that the profiled name is
    prefixed with ``Level2_Log`` so level-2 log lines are easy to tell apart.
    """

    def __init__(self, name, recorder_mode=0, metrics_func=None, metrics_labels=None):
        prefixed_name = f"Level2_Log {name}"
        super().__init__(
            prefixed_name,
            recorder_mode=recorder_mode,
            metrics_func=metrics_func,
            metrics_labels=metrics_labels,
        )
124
125
126
127
128
129
130
131
132


"""
PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling
PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1
PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2
"""
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling