profiler.py 2.97 KB
Newer Older
PengGao's avatar
PengGao committed
1
import asyncio
PengGao's avatar
PengGao committed
2
import time
PengGao's avatar
PengGao committed
3
from functools import wraps
PengGao's avatar
PengGao committed
4
5

import torch
helloyongyang's avatar
helloyongyang committed
6
import torch.distributed as dist
root's avatar
root committed
7
from loguru import logger
8

PengGao's avatar
PengGao committed
9
10
from lightx2v.utils.envs import *

11

PengGao's avatar
PengGao committed
12
class _ProfilingContext:
13
14
    def __init__(self, name):
        self.name = name
15
        self.rank_info = ""
helloyongyang's avatar
helloyongyang committed
16
17
18
19
        if dist.is_initialized():
            self.rank_info = f"Rank {dist.get_rank()}"
        else:
            self.rank_info = "Single GPU"
20
21
22

    def __enter__(self):
        torch.cuda.synchronize()
23
24
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
25
26
27
28
29
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
30
31
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # 转换为GB
helloyongyang's avatar
helloyongyang committed
32
            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
33
        else:
helloyongyang's avatar
helloyongyang committed
34
            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
35
        elapsed = time.perf_counter() - self.start_time
helloyongyang's avatar
helloyongyang committed
36
        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
37
38
        return False

PengGao's avatar
PengGao committed
39
40
41
42
43
44
    async def __aenter__(self):
        torch.cuda.synchronize()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        self.start_time = time.perf_counter()
        return self
45

PengGao's avatar
PengGao committed
46
47
48
49
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # 转换为GB
helloyongyang's avatar
helloyongyang committed
50
            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
PengGao's avatar
PengGao committed
51
        else:
helloyongyang's avatar
helloyongyang committed
52
            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
PengGao's avatar
PengGao committed
53
        elapsed = time.perf_counter() - self.start_time
helloyongyang's avatar
helloyongyang committed
54
        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
PengGao's avatar
PengGao committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
        return False

    def __call__(self, func):
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return sync_wrapper


class _NullContext:
77
78
79
80
81
82
83
84
85
86
    # Context manager without decision branch logic overhead
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False

PengGao's avatar
PengGao committed
87
88
89
90
91
92
93
94
95
    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return False

    def __call__(self, func):
        return func

96
97

# Always-on public alias: regular profiling is never disabled.
ProfilingContext = _ProfilingContext
# Debug-level profiling degrades to a no-op context unless the flag is set.
# NOTE(review): CHECK_ENABLE_PROFILING_DEBUG presumably comes from the
# `lightx2v.utils.envs` star import — confirm against that module.
ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext