import time
import torch

import asyncio
from functools import wraps

from lightx2v.utils.envs import *

from loguru import logger


class _ProfilingContext:
    """Log wall-clock time and peak CUDA memory for a named code region.

    Usable three ways with identical behavior:
      * synchronous context manager (``with _ProfilingContext("x"): ...``)
      * asynchronous context manager (``async with _ProfilingContext("x"): ...``)
      * decorator on either a sync or an async function.

    When ``torch.distributed`` is initialized, every log line is prefixed
    with the caller's rank so multi-process output stays attributable.
    """

    def __init__(self, name):
        # Label used in every log line emitted for this region.
        self.name = name
        # Rank prefix is empty unless a distributed process group is active.
        self.rank_info = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            self.rank_info = f"Rank {rank} - "

    def _start(self):
        # Guard BEFORE synchronizing: torch.cuda.synchronize() raises on a
        # CPU-only build, so all CUDA calls must sit behind is_available().
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
        self.start_time = time.perf_counter()

    def _finish(self):
        # Mirror of _start(): synchronize so pending kernels are counted,
        # then report peak memory; fall back to a GPU-less message otherwise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
        else:
            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
        elapsed = time.perf_counter() - self.start_time
        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")

    def __enter__(self):
        self._start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._finish()
        # Returning False never suppresses exceptions from the body.
        return False

    async def __aenter__(self):
        self._start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._finish()
        return False

    def __call__(self, func):
        """Decorator entry point: wrap ``func`` with sync or async profiling.

        The wrapper kind is chosen to match ``func`` so awaiting semantics
        are preserved; ``functools.wraps`` keeps the original metadata.
        """
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return sync_wrapper


class _NullContext:
73
74
75
76
77
78
79
80
81
82
    # Context manager without decision branch logic overhead
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False

PengGao's avatar
PengGao committed
83
84
85
86
87
88
89
90
91
    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return False

    def __call__(self, func):
        return func

# Public alias: always-on profiling context/decorator.
ProfilingContext = _ProfilingContext

# Debug-only profiling: resolves once at import time to the real profiler
# when the env flag is set, otherwise to the zero-overhead null context.
ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext