logs.py 1.75 KB
Newer Older
mashun1's avatar
veros  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sys
import warnings


# Lower-case names of the standard log levels, ordered from most verbose
# ("trace") to least verbose ("error").
LOGLEVELS = (
    "trace",
    "debug",
    "info",
    "warning",
    "error",
)


def _inject_proc_rank(record):
    """Loguru patcher that tags a log record with the current process rank.

    Sets ``record["extra"]["proc_rank"]`` to ``veros.runtime_state.proc_rank``.
    The record is modified in place and nothing is returned (``dict.update``
    returns ``None``, so the previous ``return record["extra"].update(...)``
    only obscured that).

    Args:
        record: The loguru record dict passed to a patcher.
    """
    # Imported lazily inside the patcher; presumably to avoid a circular
    # import with the ``veros`` package at module load time — confirm.
    # Because this runs per log call, the rank is read at logging time.
    from veros import runtime_state

    record["extra"]["proc_rank"] = runtime_state.proc_rank


def setup_logging(loglevel="info", stream_sink=sys.stdout, log_all_processes=False):
    """Configure loguru for veros and return a rank-aware logger.

    Registers a custom ``DIAGNOSTIC`` level (severity number 45), assigns
    colours to all levels, installs a single handler writing to
    *stream_sink*, routes :mod:`warnings` output through the logger, and
    enables log output from the ``veros`` package.

    Args:
        loglevel: Minimum level name to emit; it is upper-cased before being
            passed to loguru (e.g. one of ``LOGLEVELS``).
        stream_sink: Sink handed to loguru for the single handler.
        log_all_processes: If true, every record is emitted and prefixed with
            its process rank. Otherwise only records whose
            ``extra["proc_rank"]`` equals 0 pass the handler filter.

    Returns:
        The loguru logger patched with ``_inject_proc_rank``, so every record
        it produces carries ``extra["proc_rank"]``.

    Note:
        This has process-wide side effects. ``logger.configure`` replaces the
        existing loguru handlers, ``warnings.showwarning`` is overridden, and
        a ``diagnostic`` method is attached to loguru's Logger class.
    """
    from loguru import logger

    # Decide on ANSI colours from the sink we actually write to, not from
    # sys.stdout: e.g. a redirected stderr must not get colour codes just
    # because stdout is a terminal. For sinks without ``isatty`` (file paths,
    # callables, ...) pass None so loguru makes its own choice.
    sink_isatty = getattr(stream_sink, "isatty", None)
    colorize = sink_isatty() if callable(sink_isatty) else None

    handler_conf = dict(
        sink=stream_sink,
        level=loglevel.upper(),
        colorize=colorize,
    )

    # Register the custom level only if it does not exist yet: loguru refuses
    # to re-register an existing level with a severity number. Query the level
    # directly instead of inferring it from the ``diagnostic`` attribute.
    try:
        logger.level("DIAGNOSTIC")
    except ValueError:
        logger.level("DIAGNOSTIC", no=45)

    logger.level("TRACE", color="<dim>")
    logger.level("DEBUG", color="<dim><cyan>")
    logger.level("INFO", color="")
    logger.level("SUCCESS", color="<dim><green>")
    logger.level("WARNING", color="<yellow>")
    logger.level("ERROR", color="<bold><red>")
    logger.level("DIAGNOSTIC", color="<bold><yellow>")
    logger.level("CRITICAL", color="<bold><red><WHITE>")

    # From here on ``logger`` is the patched logger; the closures below
    # (diagnostic, showwarning) log through it, so their records get proc_rank.
    logger = logger.patch(_inject_proc_rank)
    if log_all_processes:
        handler_conf.update(format="{extra[proc_rank]} | <level>{message}</level>")
    else:
        handler_conf.update(format="<level>{message}</level>", filter=lambda record: record["extra"]["proc_rank"] == 0)

    def diagnostic(_, message, *args, **kwargs):
        # ``_`` is the logger instance the method is invoked on. It is ignored
        # on purpose so the call goes through the patched logger above.
        # depth=1 attributes the record to the caller of ``.diagnostic()``.
        logger.opt(depth=1).log("DIAGNOSTIC", message, *args, **kwargs)

    logger.__class__.diagnostic = diagnostic

    def showwarning(message, cls, source, lineno, *args):
        # Replacement for warnings.showwarning; the optional trailing
        # ``file``/``line`` arguments are accepted and ignored.
        logger.warning(
            "{warning}: {message} ({source}:{lineno})",
            message=message,
            warning=cls.__name__,
            source=source,
            lineno=lineno,
        )

    warnings.showwarning = showwarning

    logger.configure(handlers=[handler_conf])
    logger.enable("veros")

    return logger