# misc.py 2.94 KB
# Newer Older
# sunxx1's avatar
# sunxx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import torch
import logging

# Logger shared by every helper in this module. Calling getLogger() without a
# name returns the root logger, so output follows the root logging config.
logger = logging.getLogger()


def check_keys(model, checkpoint):
    """Compare a model's parameter names with those stored in a checkpoint.

    Names present only in ``model.state_dict()`` and names present only in
    ``checkpoint['state_dict']`` are logged as warnings; names found in both
    are logged at info level.

    Args:
        model: object exposing ``state_dict()``.
        checkpoint: mapping holding the saved weights under ``'state_dict'``.

    Returns:
        set: the names that occur in both the model and the checkpoint.
    """
    in_model = set(model.state_dict().keys())
    in_ckpt = set(checkpoint['state_dict'].keys())

    # Expected by the model but absent from the checkpoint.
    for name in in_model - in_ckpt:
        logger.warning('missing key in model:{}'.format(name))

    # Stored in the checkpoint but unknown to the model.
    for name in in_ckpt - in_model:
        logger.warning('unexpected key in checkpoint:{}'.format(name))

    common = in_model & in_ckpt
    for name in common:
        logger.info('shared key:{}'.format(name))
    return common


def accuracy(output, target, topk=(1, ), raw=False):
    """Score predictions against targets for each requested top-k cutoff.

    Args:
        output: score tensor; the ``max(topk)`` highest-scoring indices are
            taken along dimension 1.
        target: tensor of ground-truth indices, flattened to one row and
            compared against every rank of the predictions.
        topk: iterable of cutoffs ``k`` to evaluate.
        raw: when true, return the number of hits for each ``k``; otherwise
            return the hits scaled by ``100.0 / target.size(0)``.

    Returns:
        list: one single-element float tensor per entry of ``topk``.
    """
    with torch.no_grad():
        deepest = max(topk)

        # Row r of `ranked` holds every sample's (r + 1)-th best guess.
        _, ranked = output.topk(deepest, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        def _score(k):
            # A sample counts as a hit if the target is in its first k guesses.
            n_hits = hits[:k].contiguous().view(-1).float().sum(
                0, keepdim=True)
            return n_hits if raw else n_hits.mul(100.0 / target.size(0))

        return [_score(k) for k in topk]


class AverageMeter(object):
    """Track the latest value of a metric together with an average of it.

    ``length`` selects how the average is formed:

    * ``length > 1``: mean of the most recent ``length`` values, which are
      kept in ``self.history`` (a sliding window).
    * ``length < 0``: running mean of every value seen since the last
      ``reset()``; only ``self.sum`` and ``self.count`` are stored.
    * ``length`` of 0 or 1: nothing is accumulated, so the average is simply
      the latest value.

    Args:
        name: label printed in front of the values.
        fmt: format spec (including the leading colon, e.g. ``':.3f'``)
            applied to both the value and the average.
        length: averaging mode, see above.
    """

    def __init__(self, name, fmt=':f', length=1):
        self.name = name
        self.fmt = fmt
        self.length = length
        self.reset()

    def reset(self):
        """Discard accumulated statistics and zero ``val`` and ``avg``."""
        if self.length > 1:
            self.history = []
        elif self.length < 0:
            self.count = 0
            self.sum = 0
        self.avg = 0
        self.val = 0

    def update(self, val):
        """Record ``val`` as the latest value and refresh ``avg``."""
        self.val = val
        if self.length > 1:
            self.history.append(val)
            # Values arrive one at a time, so at most one entry is surplus.
            if len(self.history) > self.length:
                del self.history[0]
            self.avg = np.mean(self.history)
        elif self.length < 0:
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count
        else:
            # No accumulation in this mode: keep ``avg`` in step with the
            # latest value instead of leaving it stuck at its reset value.
            self.avg = val

    def __str__(self):
        # The average adds nothing when it always equals the current value.
        if self.length == 0 or self.length == 1:
            fmtstr = '{name} {val' + self.fmt + '}'
        else:
            fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Log a ``[batch/total]`` counter followed by a set of meters.

    Args:
        num_batches: total number of batches, shown as the denominator.
        *meters: objects whose ``str()`` is appended to each logged line.
        prefix: text placed in front of the batch counter.
    """

    def __init__(self, num_batches, *meters, prefix=''):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log the counter for ``batch`` and every meter on one line."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        logger.info(' '.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build the counter template, e.g. ``'[{:3d}/100]'`` for 100."""
        # Pad the running index to the width of the total so columns align.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))