# utils.py
import logging
import os
import shutil
from collections import defaultdict, deque

import torch


class MetricLogger:
    r"""Logger for model metrics.

    Values are recorded per key with ``logger[key] = value``. Each key keeps a
    deque bounded to ``print_freq`` entries, and only the most recent value of
    each key is reported. Calling the instance advances an internal counter
    and prints the latest values once every ``print_freq`` calls.

    Args:
        group: Value stored under the ``"group"`` key, so it appears in every
            printed line.
        print_freq: Print on every ``print_freq``-th call. Must be >= 1.

    Raises:
        ValueError: If ``print_freq`` is smaller than 1.
    """

    def __init__(self, group, print_freq=1):
        # print_freq == 0 would create zero-length deques (every recorded value
        # is dropped, so ``_get_last`` raises IndexError) and make ``__call__``
        # compute ``% 0``; reject it here with a clear message instead.
        if print_freq < 1:
            raise ValueError(f"print_freq must be >= 1, got {print_freq!r}")
        self.print_freq = print_freq
        self._iter = 0
        # Each key lazily gets its own deque bounded to print_freq entries.
        self.data = defaultdict(lambda: deque(maxlen=self.print_freq))
        self.data["group"].append(group)

    def __setitem__(self, key, value):
        """Record ``value`` as the newest entry for metric ``key``."""
        self.data[key].append(value)

    def _get_last(self):
        """Return a dict mapping every key to its most recently recorded value."""
        return {k: v[-1] for k, v in self.data.items()}

    def __str__(self):
        return str(self._get_last())

    def __call__(self):
        """Advance the step counter; print the latest values every ``print_freq`` calls."""
        self._iter = (self._iter + 1) % self.print_freq
        if not self._iter:
            print(self, flush=True)


def save_checkpoint(state, is_best, filename, *, best_filename="model_best.pth.tar"):
    r"""Save ``state`` to ``filename`` by way of a temporary file.

    The state is first written with ``torch.save()`` to ``filename + ".temp"``
    and only then moved onto ``filename``. A signal that interrupts
    ``torch.save()`` therefore leaves any previous checkpoint at ``filename``
    untouched.

    Args:
        state: Object passed to ``torch.save()``.
        is_best: If true, also copy the saved checkpoint to ``best_filename``.
        filename: Destination path. An empty string disables saving entirely;
            the function returns without writing anything.
        best_filename: Destination of the copy made when ``is_best`` is true.
            Defaults to ``"model_best.pth.tar"``, resolved against the current
            working directory.
    """

    if filename == "":
        return

    tmp_path = filename + ".temp"

    # A stale temp file can be left behind by an earlier save that was
    # interrupted before the final move; remove it before writing a new one.
    if os.path.isfile(tmp_path):
        os.remove(tmp_path)

    torch.save(state, tmp_path)
    if os.path.isfile(tmp_path):
        # os.replace overwrites an existing destination on every platform
        # (os.rename raises on Windows when ``filename`` already exists), and
        # the move is atomic on POSIX when both paths are on one filesystem.
        os.replace(tmp_path, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
    logging.info("Checkpoint: saved")


def count_parameters(model, *, trainable_only=True):
    r"""Count the parameters of ``model``.

    Args:
        model: Object whose ``parameters()`` yields tensors, such as a
            ``torch.nn.Module``.
        trainable_only: If true (the default), count only parameters whose
            ``requires_grad`` is set; otherwise count every parameter.

    Returns:
        The sum of ``numel()`` over the selected parameters.
    """

    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )