logger.py 3.32 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import csv
import os.path

import matplotlib

matplotlib.use('Agg')

from matplotlib import pyplot as plt
import numpy as np

plt.switch_backend('agg')


class CsvLogger:
    def __init__(self, filepath='./', filename='results.csv', data=None):
        self.log_path = filepath
        self.log_name = filename
        self.csv_path = os.path.join(self.log_path, self.log_name)
        self.fieldsnames = ['epoch', 'val_error1', 'val_error5', 'val_loss', 'train_error1', 'train_error5',
                            'train_loss']

        with open(self.csv_path, 'w') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldsnames)
            writer.writeheader()

        self.data = {}
        for field in self.fieldsnames:
            self.data[field] = []
        if data is not None:
            for d in data:
                d_num = {}
                for key in d:
                    d_num[key] = float(d[key]) if key != 'epoch' else int(d[key])
                self.write(d_num)

    def write(self, data):
        for k in self.data:
            self.data[k].append(data[k])
        with open(self.csv_path, 'a') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldsnames)
            writer.writerow(data)

    def save_params(self, args, params):
        with open(os.path.join(self.log_path, 'params.txt'), 'w') as f:
            f.write('{}\n'.format(' '.join(args)))
            f.write('{}\n'.format(params))

    def write_text(self, text, print_t=True):
        with open(os.path.join(self.log_path, 'params.txt'), 'a') as f:
            f.write('{}\n'.format(text))
        if print_t:
            print(text)

    def plot_progress_errk(self, claimed_acc=None, title='ShuffleNetv2', k=1):
        tr_str = 'train_error{}'.format(k)
        val_str = 'val_error{}'.format(k)
        plt.figure(figsize=(9, 8), dpi=300)
        plt.plot(self.data[tr_str], label='Training error')
        plt.plot(self.data[val_str], label='Validation error')
        if claimed_acc is not None:
            plt.plot((0, len(self.data[tr_str])), (1 - claimed_acc, 1 - claimed_acc), 'k--',
                     label='Claimed validation error ({:.2f}%)'.format(100. * (1 - claimed_acc)))
        plt.plot((0, len(self.data[tr_str])),
                 (np.min(self.data[val_str]), np.min(self.data[val_str])), 'r--',
                 label='Best validation error ({:.2f}%)'.format(100. * np.min(self.data[val_str])))
        plt.title('Top-{} error for {}'.format(k, title))
        plt.xlabel('Epoch')
        plt.ylabel('Error')
        plt.legend()
        plt.xlim(0, len(self.data[tr_str]) + 1)
        plt.savefig(os.path.join(self.log_path, 'top{}.png'.format(k)))

    def plot_progress_loss(self, title='ShuffleNetv2'):
        plt.figure(figsize=(9, 8), dpi=300)
        plt.plot(self.data['train_loss'], label='Training')
        plt.plot(self.data['val_loss'], label='Validation')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.xlim(0, len(self.data['train_loss']) + 1)
        plt.savefig(os.path.join(self.log_path, 'loss.png'))

    def plot_progress(self, claimed_acc1=None, claimed_acc5=None, title='ShuffleNetv2'):
        self.plot_progress_errk(claimed_acc1, title, 1)
        self.plot_progress_errk(claimed_acc5, title, 5)
        self.plot_progress_loss(title)
        plt.close('all')