logger.py 5.72 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import atexit
import functools
import logging
import os
import re
import sys

from accelerate.logging import get_logger
from fvcore.common.file_io import PathManager
from termcolor import colored


def create_logger(output_dir=None, dist_rank=0):
    """Configure and return the root logger for (distributed) training.

    Only the master process (``dist_rank == 0``) receives handlers: a
    colored stdout handler and, when ``output_dir`` is given, a plain-text
    file handler appending to ``training.log`` in that directory.

    :param output_dir: directory for the log file; no file logging if falsy
    :param dist_rank: rank of this process; non-master ranks get no handlers
    :return: the configured root logger (level INFO, propagation disabled)
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.propagate = False

    plain_fmt = "[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s"
    colored_fmt = "".join(
        [
            colored("[%(asctime)s %(name)s]", "green"),
            colored("(%(filename)s %(lineno)d)", "yellow"),
            ": %(levelname)s %(message)s",
        ]
    )

    if dist_rank == 0:
        # Colored console output goes to the master process only.
        to_console = logging.StreamHandler(sys.stdout)
        to_console.setLevel(logging.DEBUG)
        to_console.setFormatter(logging.Formatter(fmt=colored_fmt, datefmt="%Y-%m-%d %H:%M:%S"))
        root.addHandler(to_console)

        if output_dir:
            # The on-disk log uses the plain (uncolored) format.
            to_file = logging.FileHandler(os.path.join(output_dir, "training.log"), mode="a")
            to_file.setFormatter(logging.Formatter(fmt=plain_fmt, datefmt="%Y-%m-%d %H:%M:%S"))
            root.addHandler(to_file)

    return root


class _ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


class ColorFilter(logging.Filter):
    """Strip ANSI color escape sequences from records bound for plain sinks.

    Attached to file handlers so that messages formatted with terminal
    color codes are written to disk as plain text.
    """

    # Matches ANSI SGR escape sequences such as "\x1b[32m"; compiled once
    # instead of on every filtered record.
    _ANSI_PATTERN = re.compile(r'\x1b\[[0-9;]*m')

    def filter(self, record):
        """Rewrite ``record.msg`` without color codes; never drops a record.

        :param record: the ``logging.LogRecord`` being emitted
        :return: always ``True`` (this filter only rewrites, never rejects)
        """
        message = record.getMessage()
        if self._ANSI_PATTERN.search(message):
            record.msg = self._ANSI_PATTERN.sub('', message)
            # getMessage() already interpolated any %-args into ``message``;
            # clear them so a later getMessage() on this record does not try
            # to re-apply args to the substituted string (which would raise).
            record.args = None
        return True


@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None,
    distributed_rank=0,
    *,
    color=True,
    name="detection",
    abbrev_name=None,
    enable_propagation: bool = False,
    configure_stdout: bool = True,
):
    """Initialize the detection logger and set its verbosity level to "DEBUG"

    :param output: a file name or a directory to save log. If None, will not save log file.
        If ends with ".txt" or ".log", assumed to be a file name, defaults to None
    :param distributed_rank: rank number id in distributed training, defaults to 0
    :param color: whether to show colored logging information, defaults to True
    :param name: the root module name of this logger, defaults to "detection"
    :param abbrev_name: an abbreviation of the module, to avoid long names in logs.
        Set to "" to not log the root module in logs. By default, will abbreviate "detection"
        to "det" and leave other modules unchanged, defaults to None
    :param enable_propagation: whether to propogate logs to the parent logger, defaults to False
    :param configure_stdout: whether to configure logging to stdout, defaults to True
    :return: the accelerate logger adapter wrapping the configured logger
    """
    logger_adapter = get_logger(name, "DEBUG")
    logger = logger_adapter.logger
    logger.propagate = enable_propagation

    if abbrev_name is None:
        # NOTE(review): this abbreviates the basename of the *current working
        # directory* within ``name``, not the literal "detection" that the
        # docstring describes — confirm this is intentional.
        abbrev_name = name.replace(os.path.basename(os.getcwd()), "det")

    plain_formatter = logging.Formatter(
        "[%(asctime)s %(name)s] %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    # stdout logging: master only
    if configure_stdout and distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.log")
        if distributed_rank > 0:
            # Insert the rank suffix before the extension only. The previous
            # filename.replace(".", ...) mangled every dot in the path, e.g.
            # "runs/v1.0/log.log" -> "runs/v1_rank1.0/log_rank1.log".
            stem, ext = os.path.splitext(filename)
            filename = "{}_rank{}{}".format(stem, distributed_rank, ext)
        # dirname is "" for a bare filename; makedirs("") would raise.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.addFilter(ColorFilter())  # keep the on-disk log free of ANSI codes
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger_adapter


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open *filename* for appending once per process and reuse the handle.

    :param filename: local path or remote URI understood by ``PathManager``
    :return: an append-mode file-like object, closed automatically at exit
    """
    buffer_size = _get_log_stream_buffer_size(filename)
    stream = PathManager.open(filename, "a", buffering=buffer_size)
    # Flush and close the shared handle when the interpreter exits.
    atexit.register(stream.close)
    return stream


def _get_log_stream_buffer_size(filename: str) -> int:
    if "://" not in filename:
        # Local file, no extra caching is necessary
        return -1
    # Remote file requires a larger cache to avoid many small writes.
    return 1024 * 1024