profiler.py 2.15 KB
Newer Older
PengGao's avatar
PengGao committed
1
import asyncio
PengGao's avatar
PengGao committed
2
import time
PengGao's avatar
PengGao committed
3
from functools import wraps
PengGao's avatar
PengGao committed
4
5

import torch
helloyongyang's avatar
helloyongyang committed
6
import torch.distributed as dist
root's avatar
root committed
7
from loguru import logger
8

PengGao's avatar
PengGao committed
9
10
from lightx2v.utils.envs import *

11

PengGao's avatar
PengGao committed
12
class _ProfilingContext:
    """Measure the wall-clock duration of a code region and log it.

    An instance can be used in three ways:

    * as a synchronous context manager (``with ctx:``),
    * as an asynchronous context manager (``async with ctx:``),
    * as a decorator for plain functions and coroutine functions.

    CUDA work is synchronized before the clock is read (when a CUDA device
    is usable) so that queued GPU kernels are included in the measurement.

    Args:
        name: Label printed in the log line to identify the timed region.

    Attributes:
        name: The label passed to the constructor.
        rank_info: ``"Rank <n>"`` when ``torch.distributed`` is initialized,
            otherwise ``"Single GPU"``. Refreshed every time a timing is
            logged.
        start_time: ``time.perf_counter()`` value captured by the most recent
            ``__enter__`` / ``__aenter__``.
    """

    def __init__(self, name):
        self.name = name
        # Best-effort value at construction time. Decorators are usually
        # created at import time, before the process group is initialized,
        # so the value is resolved again when a timing is logged.
        self.rank_info = self._resolve_rank_info()

    @staticmethod
    def _resolve_rank_info():
        """Return the rank label for the current process."""
        # is_available() guards torch builds without distributed support.
        if dist.is_available() and dist.is_initialized():
            return f"Rank {dist.get_rank()}"
        return "Single GPU"

    @staticmethod
    def _synchronize():
        """Wait for pending CUDA work, if a CUDA device is usable."""
        # Skipping the call on hosts without CUDA keeps the profiler usable
        # there instead of failing inside torch.cuda.synchronize().
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _begin(self):
        """Synchronize and return the start timestamp."""
        self._synchronize()
        return time.perf_counter()

    def _finish(self, start_time):
        """Synchronize, compute the elapsed time since *start_time*, log it."""
        self._synchronize()
        elapsed = time.perf_counter() - start_time
        self.rank_info = self._resolve_rank_info()
        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")

    def __enter__(self):
        self.start_time = self._begin()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._finish(self.start_time)
        # Never suppress an exception raised inside the timed region.
        return False

    async def __aenter__(self):
        # NOTE: the CUDA synchronization is a blocking call; it is not awaited.
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return self.__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        """Decorate *func* so that every call is timed and logged.

        Each call keeps its own start timestamp in a local variable rather
        than on the shared instance, so recursive calls and concurrently
        awaited coroutines do not overwrite one another's measurement.

        Args:
            func: A plain function or a coroutine function.

        Returns:
            A wrapper of the same kind (sync or async) as *func*.
        """
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                started = self._begin()
                try:
                    return await func(*args, **kwargs)
                finally:
                    # Log even when func raises; the exception propagates.
                    self._finish(started)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            started = self._begin()
            try:
                return func(*args, **kwargs)
            finally:
                # Log even when func raises; the exception propagates.
                self._finish(started)

        return sync_wrapper


class _NullContext:
    """No-op stand-in exposing the same usage surface as the profiler.

    Supports ``with``, ``async with`` and decorator usage while doing
    nothing, so call sites need no branching when profiling is turned off.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def __enter__(self):
        """Enter the (empty) context and hand back this object."""
        return self

    def __exit__(self, *exc_info):
        """Leave the context; ``False`` lets any exception propagate."""
        return False

    async def __aenter__(self):
        """Async counterpart of ``__enter__``."""
        return self

    async def __aexit__(self, *exc_info):
        """Async counterpart of ``__exit__``."""
        return False

    def __call__(self, func):
        """Decorator form: return *func* itself, unwrapped."""
        return func

82
83

# Always-on profiler: every use is timed and logged.
ProfilingContext = _ProfilingContext

# Debug-only profiler: the real implementation when
# CHECK_ENABLE_PROFILING_DEBUG() is truthy, otherwise the no-op stand-in.
# NOTE(review): CHECK_ENABLE_PROFILING_DEBUG presumably comes from the
# star import of lightx2v.utils.envs -- confirm.
if CHECK_ENABLE_PROFILING_DEBUG():
    ProfilingContext4Debug = _ProfilingContext
else:
    ProfilingContext4Debug = _NullContext