stdout_logging.py 2.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench stdout logging module."""

import sys


class StdLogger:
    """Logger class to enable or disable redirecting STDOUT and STDERR to a file.

    Call ``add_file_handler(filename)`` to choose the log file, then
    ``start(rank)`` to tee ``sys.stdout`` and ``sys.stderr`` to both the
    terminal and that file, and ``stop()`` to restore the original streams.
    """
    class StdoutLoggerStream:
        """Stream that tees every write to the original sys.stdout and a log file.

        Each message is prefixed with ``[<rank>]: `` before being written.
        """
        def __init__(self, filename, rank):
            """Init the class with filename.

            Args:
                filename (str): the path of the file to save the log (opened in append mode).
                rank (int): the rank id used as the message prefix.
            """
            self._terminal = sys.stdout
            self._rank = rank
            self._log_file_handler = open(filename, 'a')

        def __getattr__(self, attr):
            """Delegate attributes not defined here (e.g. ``encoding``) to the terminal stream.

            Args:
                attr (str): Attribute name.

            Returns:
                Any: Attribute value.

            Raises:
                AttributeError: if ``_terminal`` itself is not set, or the terminal lacks ``attr``.
            """
            # __getattr__ only runs when normal lookup fails. If _terminal is missing
            # (e.g. an instance created via __new__ by copy/pickle), looking it up
            # here would re-enter __getattr__ forever; fail fast instead.
            if attr == '_terminal':
                raise AttributeError(attr)
            return getattr(self._terminal, attr)

        def write(self, message):
            """Write the rank-prefixed message to the terminal and the log file.

            The prefix is added on every call, not per line.

            Args:
                message (str): the message to log.
            """
            message = f'[{self._rank}]: {message}'
            self._terminal.write(message)
            self._log_file_handler.write(message)
            # Flush eagerly so the log file stays current even if the process dies.
            self._log_file_handler.flush()

        def flush(self):
            """Flush the terminal stream.

            The log file is already flushed after every write, so only the
            terminal needs flushing here.
            """
            self._terminal.flush()

        def restore(self):
            """Close the log file and point sys.stdout back at the terminal stream."""
            self._log_file_handler.close()
            sys.stdout = self._terminal

    def __init__(self):
        """Init the logger with no log file and no active redirection."""
        self.filename = None
        self.logger_stream = None
        # sys.stderr as it was before start(); restored by stop().
        self._saved_stderr = None

    def add_file_handler(self, filename):
        """Set the file that subsequent start() calls will log into.

        Args:
            filename (str): the path of file to save the log
        """
        self.filename = filename

    def start(self, rank):
        """Start redirecting sys.stdout and sys.stderr to the terminal and the log file.

        If a redirection is already active it is stopped first. That closes its
        log file, and the new stream wraps the real terminal rather than the
        previous logger stream.

        Args:
            rank (int): the rank id used to prefix every message.

        Raises:
            RuntimeError: if add_file_handler() has not been called.
        """
        if self.filename is None:
            raise RuntimeError('add_file_handler() must be called before start()')
        self.stop()
        # Open the file before touching sys.* so a failed open leaves streams intact.
        stream = self.StdoutLoggerStream(self.filename, rank)
        self._saved_stderr = sys.stderr
        self.logger_stream = stream
        sys.stdout = stream
        sys.stderr = stream

    def stop(self):
        """Restore sys.stdout and sys.stderr and close the log file.

        Safe to call when no redirection is active.
        """
        if self.logger_stream is None:
            return
        self.logger_stream.restore()
        self.logger_stream = None
        # start() also redirected sys.stderr; restore() only handles sys.stdout.
        if self._saved_stderr is not None:
            sys.stderr = self._saved_stderr
            self._saved_stderr = None

    def log(self, message):
        """Write the message into the logger.

        When redirection is active the message goes through the logger stream
        (rank-prefixed, teed to the file); otherwise it goes to sys.stdout as-is.

        Args:
            message (str): the message to log.
        """
        if self.logger_stream is not None:
            self.logger_stream.write(message)
        else:
            sys.stdout.write(message)


# Module-level shared StdLogger instance. Its file path is unset until
# add_file_handler() is called, which must happen before start().
stdout_logger = StdLogger()