logger.py 11.4 KB
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import os
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional

from ase.data import atomic_numbers

import sevenn._keys as KEY
from sevenn import __version__

# Reverse lookup of ase's symbol -> atomic number table (atomic number -> symbol)
CHEM_SYMBOLS = {number: symbol for symbol, number in atomic_numbers.items()}


class Singleton(type):
    """
    Metaclass that keeps a single instance per class.

    The first call to a class using this metaclass builds the instance with
    the given arguments; every later call returns that same instance and
    ignores its arguments.
    """

    # Maps each class to its one instance, shared by all users of the metaclass
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class Logger(metaclass=Singleton):
    SCREEN_WIDTH = 120  # half size of my screen / changed due to stress output

    def __init__(
        self, filename: Optional[str] = None, screen: bool = False, rank: int = 0
    ):
        self.rank = rank
        self._filename = filename
        if rank == 0:
            # if filename is not None:
            #    self.logfile = open(filename, 'a', buffering=1)
            self.logfile = None
            self.files = {}
            self.screen = screen
        else:
            self.logfile = None
            self.screen = False
        self.timer_dct = {}
        self.active = True

    def __enter__(self):
        if self.rank != 0:
            return self
        if self.logfile is None and self._filename is not None:
            try:
                self.logfile = open(
                    self._filename, 'a', buffering=1, encoding='utf-8'
                )
            except IOError as e:
                print(f'Failed to re-open log file {self._filename}: {e}')
                self.logfile = None
            self.files = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.rank != 0:
            return self
        try:
            if self.logfile is not None:
                self.logfile.close()
                self.logfile = None
            for f in self.files.values():
                f.close()
        except IOError as e:
            print(f'Failed to close log files: {e}')
        finally:
            self.logfile = None
            self.files = {}

    def switch_file(self, new_filename: str):
        if self.rank != 0:
            return self
        if self.logfile is not None:
            raise ValueError('Current logfile is not yet closed')
        self._filename = new_filename
        return self

    def write(self, content: str):
        if self.rank != 0:
            return
        # no newline!
        if self.logfile is not None and self.active:
            self.logfile.write(content)
        if self.screen and self.active:
            print(content, end='')

    def writeline(self, content: str):
        content = content + '\n'
        self.write(content)

    def init_csv(self, filename: str, header: list):
        """
        Deprecated
        """
        if self.rank == 0:
            self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8')
            self.files[filename].write(','.join(header) + '\n')
        else:
            pass

    def append_csv(self, filename: str, content: list, decimal: int = 6):
        """
        Deprecated
        """
        if self.rank == 0:
            if filename not in self.files:
                self.files[filename] = open(filename, 'a', buffering=1)
            str_content = []
            for c in content:
                if isinstance(c, float):
                    str_content.append(f'{c:.{decimal}f}')
                else:
                    str_content.append(str(c))
            self.files[filename].write(','.join(str_content) + '\n')
        else:
            pass

    def natoms_write(self, natoms: Dict[str, Dict]):
        content = ''
        total_natom = {}
        for label, natom in natoms.items():
            content += self.format_k_v(label, natom)
            for specie, num in natom.items():
                try:
                    total_natom[specie] += num
                except KeyError:
                    total_natom[specie] = num
        content += self.format_k_v('Total, label wise', total_natom)
        content += self.format_k_v('Total', sum(total_natom.values()))
        self.write(content)

    def statistic_write(self, statistic: Dict[str, Dict]):
        content = ''
        for label, dct in statistic.items():
            if label.startswith('_'):
                continue
            if not isinstance(dct, dict):
                continue
            dct_new = {}
            for k, v in dct.items():
                if k.startswith('_'):
                    continue
                if isinstance(v, int):
                    dct_new[k] = v
                else:
                    dct_new[k] = f'{v:.3f}'
            content += self.format_k_v(label, dct_new)
        self.write(content)

    # TODO : refactoring!!!, this is not loss, rmse
    def epoch_write_specie_wise_loss(self, train_loss, valid_loss):
        """
        Write one table row per key of train_loss: the chemical symbol looked
        up in CHEM_SYMBOLS, then six value columns. Only the third and fourth
        columns carry numbers (train and valid value of that key), the other
        four are filled with dashes.

        NOTE(review): the dashed columns look like placeholders for energy
        and stress, which have no per-specie value; confirm against the
        table header written by the caller.
        """
        label_width = 21
        precision = 6
        cell_width = 21 - precision
        dashes = '-' * precision
        # Placeholder cell used for the columns without a per-specie value
        blank = f'{dashes:<{cell_width}.{precision}s}'

        rows = []
        for atom_type in train_loss.keys():
            train_val = train_loss[atom_type]
            valid_val = valid_loss[atom_type]
            symbol = CHEM_SYMBOLS[atom_type]
            rows.append(
                f'{symbol:{label_width}}{blank}{blank}'
                f'{train_val:<{cell_width}.{precision}f}'
                f'{valid_val:<{cell_width}.{precision}f}'
                f'{blank}{blank}\n'
            )
        self.write(''.join(rows))

    def write_full_table(
        self,
        dict_list: List[Dict],
        row_labels: List[str],
        decimal_places: int = 6,
        pad: int = 2,
    ):
        """
        Assume data_list is list of dict with same keys
        """
        assert len(dict_list) == len(row_labels)
        label_len = max(map(len, row_labels))
        # Extract the column names and create a 2D array of values
        col_names = list(dict_list[0].keys())

        values = [list(d.values()) for d in dict_list]

        # Format the numbers with the given decimal places
        formatted_values = [
            [f'{value:.{decimal_places}f}' for value in row] for row in values
        ]

        # Calculate padding lengths for each column (with extra padding)
        max_col_lengths = [
            max(len(str(value)) for value in col) + pad
            for col in zip(col_names, *formatted_values)
        ]

        # Create header row and separator
        header = ' ' * (label_len + pad) + ' '.join(
            col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths)
        )
        separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * (
            label_len + pad
        )

        # Print header and separator
        self.writeline(header)
        self.writeline(separator)

        # Print the data rows with row labels
        for row_label, row in zip(row_labels, formatted_values):
            data_row = ' '.join(
                value.rjust(pad) for value, pad in zip(row, max_col_lengths)
            )
            self.writeline(f'{row_label.ljust(label_len)}{data_row}')

    def format_k_v(self, key: Any, val: Any, write: bool = False):
        """
        key and val should be str convertible
        """
        MAX_KEY_SIZE = 20
        SEPARATOR = ', '
        EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3)
        NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5
        key = str(key)
        val = str(val)
        content = f'{key:<{MAX_KEY_SIZE}}: {val}'
        if len(content) > NEW_LINE_LEN:
            content = f'{key:<{MAX_KEY_SIZE}}: '
            # septate val by separator
            val_list = val.split(SEPARATOR)
            current_len = len(content)
            for val_compo in val_list:
                current_len += len(val_compo)
                if current_len > NEW_LINE_LEN:
                    newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}'
                    content += f'\\\n{newline_content}'
                    current_len = len(newline_content)
                else:
                    content += f'{val_compo}{SEPARATOR}'

        if content.endswith(f'{SEPARATOR}'):
            content = content[: -len(SEPARATOR)]
        content += '\n'

        if write is False:
            return content
        else:
            self.write(content)
            return ''

    def greeting(self):
        """
        Write the program name, version and current time, followed by the
        ascii logo read from the 'logo_ascii' file next to this module.
        """
        logo_path = f'{os.path.dirname(__file__)}/logo_ascii'
        with open(logo_path, 'r') as logo_f:
            logo_ascii = logo_f.read()
        banner = (
            'SevenNet: Scalable EquiVariance-Enabled Neural Network\n'
            f'version {__version__}, {time.ctime()}\n'
        )
        self.write(banner)
        self.write(logo_ascii)

    def bar(self):
        content = '-' * Logger.SCREEN_WIDTH + '\n'
        self.write(content)

    def print_config(
        self,
        model_config: Dict[str, Any],
        data_config: Dict[str, Any],
        train_config: Dict[str, Any],
    ):
        """
        print some important information from config
        """
        content = 'successfully read yaml config!\n\n' + 'from model configuration\n'
        for k, v in model_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom train configuration\n'
        for k, v in train_config.items():
            content += self.format_k_v(k, str(v))
        content += '\nfrom data configuration\n'
        for k, v in data_config.items():
            content += self.format_k_v(k, str(v))
        self.write(content)

    # TODO: This is not good make own exception
    def error(self, e: Exception):
        content = ''
        if type(e) is ValueError:
            content += 'Error occurred!\n'
            content += str(e) + '\n'
        else:
            content += 'Unknown error occurred!\n'
            content += traceback.format_exc()
        self.write(content)

    def timer_start(self, name: str):
        self.timer_dct[name] = datetime.now()

    def timer_end(self, name: str, message: str, remove: bool = True):
        """
        print f"{message}: {elapsed}"
        """
        elapsed = str(datetime.now() - self.timer_dct[name])
        # elapsed = elapsed.strftime('%H-%M-%S')
        if remove:
            del self.timer_dct[name]
        self.write(f'{message}: {elapsed[:-4]}\n')

    # TODO: print it without config
    # TODO: refactoring, readout part name :(
    def print_model_info(self, model, config):
        """
        Write the irreps reported by the model for the edge embedding, for
        the first self interaction of every convolution layer and for the
        gate of the last layer, followed by the number of parameters that
        require gradients.

        The number of convolution layers is read from
        ``config[KEY.NUM_CONVOLUTION]``.
        """
        self.writeline('Irreps of features')
        self.format_k_v(
            'edge_feature',
            model.get_irreps_in('edge_embedding', 'irreps_out'),
            write=True,
        )
        for layer in range(config[KEY.NUM_CONVOLUTION]):
            self.format_k_v(
                f'{layer}th node',
                model.get_irreps_in(f'{layer}_self_interaction_1'),
                write=True,
            )
        # The readout takes the output of the last layer's gate
        last_layer = config[KEY.NUM_CONVOLUTION] - 1
        self.format_k_v(
            'readout irreps',
            model.get_irreps_in(f'{last_layer}_equivariant_gate', 'irreps_out'),
            write=True,
        )

        trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
        self.writeline(f'# learnable parameters: {sum(trainable_sizes)}\n')