logger.py 881 Bytes
Newer Older
mandoxzhang's avatar
mandoxzhang committed
1
import logging
2

mandoxzhang's avatar
mandoxzhang committed
3
4
import torch.distributed as dist

5
6
7
# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
mandoxzhang's avatar
mandoxzhang committed
8
9
10
# Module-level logger named after this module; Logger.__init__ asks
# logging.getLogger for the same name.
logger = logging.getLogger(__name__)


11
class Logger:
    """Logger that echoes messages through ``logging`` and appends them to a file.

    When ``cuda`` is true, ``info`` emits only on the process whose
    ``torch.distributed`` rank is 0. ``error`` is never rank-gated and does
    not write to the log file.
    """

    def __init__(self, log_path, cuda=False, debug=False):
        """Store the configuration; the log file is opened on each ``info`` call.

        Args:
            log_path: Path of the text file that ``info`` appends to.
            cuda: If true, ``info`` is restricted to distributed rank 0.
            debug: Stored on the instance; not read by this class itself.
        """
        self.logger = logging.getLogger(__name__)
        self.cuda = cuda
        self.log_path = log_path
        self.debug = debug

    def _should_emit(self):
        """Return True when this process is the one that should log ``info``."""
        if not self.cuda:
            return True
        # get_rank() raises when no default process group exists (e.g. a
        # single-GPU run without init_process_group); treat that as rank 0.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

    def info(self, message, log_=True, print_=True, *args, **kwargs):
        """Log ``message`` at INFO level and/or append it to the log file.

        Args:
            message: Object to log; it is converted with ``str`` for the file.
            log_: If true, append the message plus a newline to ``log_path``.
            print_: If true, forward the message to the ``logging`` logger.
            *args: Passed through to ``logging.Logger.info`` only; the file
                receives the message without ``%``-interpolation.
            **kwargs: Passed through to ``logging.Logger.info``.
        """
        if not self._should_emit():
            return

        if print_:
            self.logger.info(message, *args, **kwargs)

        if log_:
            # Explicit UTF-8 so non-ASCII messages do not depend on the locale.
            with open(self.log_path, "a", encoding="utf-8") as f_log:
                f_log.write(str(message) + "\n")

    def error(self, message, *args, **kwargs):
        """Log ``message`` at ERROR level; all arguments go to ``logging``."""
        self.logger.error(message, *args, **kwargs)