import asyncio
import time
from functools import wraps

import torch
from loguru import logger

from lightx2v.utils.envs import *

class _ProfilingContext:
12
13
    def __init__(self, name):
        self.name = name
14
15
16
17
        self.rank_info = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            self.rank_info = f"Rank {rank} - "
18
19
20

    def __enter__(self):
        torch.cuda.synchronize()
21
22
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
23
24
25
26
27
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
28
29
30
31
32
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # 转换为GB
            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
        else:
            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
33
        elapsed = time.perf_counter() - self.start_time
root's avatar
root committed
34
        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
35
36
        return False

PengGao's avatar
PengGao committed
37
38
39
40
41
42
    async def __aenter__(self):
        torch.cuda.synchronize()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        self.start_time = time.perf_counter()
        return self
43

PengGao's avatar
PengGao committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # 转换为GB
            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
        else:
            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
        elapsed = time.perf_counter() - self.start_time
        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
        return False

    def __call__(self, func):
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return sync_wrapper


class _NullContext:
75
76
77
78
79
80
81
82
83
84
    # Context manager without decision branch logic overhead
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False

PengGao's avatar
PengGao committed
85
86
87
88
89
90
91
92
93
    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return False

    def __call__(self, func):
        return func

# Public aliases. ProfilingContext always profiles; ProfilingContext4Debug
# profiles only when the profiling-debug env check passes, and otherwise
# degrades to the zero-overhead null context.
ProfilingContext = _ProfilingContext

if CHECK_ENABLE_PROFILING_DEBUG():
    ProfilingContext4Debug = _ProfilingContext
else:
    ProfilingContext4Debug = _NullContext