memory_profiler.py 961 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from loguru import logger


def peak_memory_decorator(func):
    def wrapper(*args, **kwargs):
        # 检查是否在分布式环境中
        rank_info = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            rank_info = f"Rank {rank} - "

        # 如果使用GPU,重置显存统计
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        # 执行目标函数
        result = func(*args, **kwargs)

        # 获取峰值显存
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # 转换为GB
            logger.info(f"{rank_info}Function '{func.__qualname__}' Peak Memory: {peak_memory:.2f} GB")
        else:
            logger.info(f"{rank_info}Function '{func.__qualname__}' executed without GPU.")

        return result

    return wrapper