draw_summary.py 5.32 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
import argparse
from collections import defaultdict, OrderedDict
import matplotlib.pyplot as plt
import numpy as np

def smooth_moving_average(x, n):
    """Smooth *x* with a length-``n`` moving average.

    The first ``n - 1`` positions, where a full window is not yet
    available, keep the raw input values so the output has the same
    length as the input.
    """
    kernel = np.full(n, 1.0 / n)
    # 'valid' mode yields len(x) - n + 1 fully-windowed averages.
    averaged_tail = np.convolve(x, kernel, mode='valid')
    return np.concatenate((x[:n - 1], averaged_tail), axis=0)

def moving_stdev(x, n):
    """Return the moving standard deviation of *x* over windows of ``n``.

    Uses the identity Var = E[x^2] - E[x]^2 computed with two moving
    averages.  The first ``n - 1`` positions, where a full window is not
    available, are padded with zeros so output length matches input.
    """
    fil = np.ones(n) / n
    avg_square = np.convolve(np.power(x, 2), fil, mode='valid')
    squared_avg = np.power(np.convolve(x, fil, mode='valid'), 2)
    # BUG FIX: floating-point cancellation can make this difference a tiny
    # negative number, and np.sqrt of a negative yields NaN; clamp at 0.
    var = np.maximum(avg_square - squared_avg, 0.0)
    stdev = np.sqrt(var)
    # Pad the first n-1 values so the curve aligns with the input steps.
    stdev = np.concatenate(([0] * (n - 1), stdev), axis=0)

    return stdev

def get_plot(log):
    """Split a metric log into parallel (steps, values) lists.

    Entries are (step, elapsedtime, value) tuples; only entries whose
    step is an int are kept (string steps such as 'PARAMETER' and
    list-valued summary steps are skipped).
    """
    steps, values = [], []
    for entry in log:
        if isinstance(entry[0], int):
            steps.append(entry[0])
            values.append(entry[2])
    return steps, values

def highlight_max_point(plot, color):
    """Mark and annotate the maximum-value point of a curve.

    ``plot`` is a (steps, values) pair; the point with the highest value
    is drawn in ``color`` and annotated with its value.  Returns the
    (step, value) tuple of that point.
    """
    point = max(zip(*plot), key=lambda p: p[1])
    # BUG FIX: the original fmt string 'bo-' hard-coded blue while also
    # passing color=..., which silently overrides it (and newer matplotlib
    # warns about the conflict).  Keep only the marker/line style here.
    plt.plot(point[0], point[1], 'o-', color=color)
    plt.annotate("{:.2f}".format(point[1]), point)
    return point

def main(args):
    """Parse a JSON-lines training log and render loss/BLEU summary plots.

    Reads ``args.log_file`` (one JSON record per line, each prefixed by a
    5-character tag -- presumably dllogger output; TODO confirm), draws
    training/validation loss and BLEU curves to ``args.output``, and
    optionally dumps a JSON training summary to ``args.dump_json``.
    """
    jlog = defaultdict(list)
    jlog['parameters'] = {}

    with open(args.log_file, 'r') as f:
        for line in f:
            # Strip the fixed 5-char prefix before the JSON payload.
            line_dict = json.loads(line[5:])
            if line_dict['type'] != 'LOG':
                continue
            if line_dict['step'] == 'PARAMETER':
                jlog['parameters'].update(line_dict['data'])
            elif line_dict['step'] == [] and 'training_summary' not in jlog:
                # An empty step list marks the final training-summary record;
                # only the first such record is kept.
                jlog['training_summary'] = line_dict['data']
            else:
                for k, v in line_dict['data'].items():
                    jlog[k].append((line_dict['step'], line_dict['elapsedtime'], v))

    fig, ax1 = plt.subplots(figsize=(20, 5))
    fig.suptitle(args.title, fontsize=16)
    ax1.set_xlabel('steps')
    ax1.set_ylabel('loss')

    # Define colors for specific curves
    VAL_LOSS_COLOR = 'blue'
    VAL_BLEU_COLOR = 'red'
    TEST_BLEU_COLOR = 'pink'

    # Plot the training loss smoothed over a 150-step window, with a
    # +/- one-moving-stdev band.
    steps, loss = get_plot(jlog['loss'])
    smoothed_loss = smooth_moving_average(loss, 150)
    stdev = moving_stdev(loss, 150)

    ax1.plot(steps, smoothed_loss, label='Training loss')
    ax1.plot(steps, smoothed_loss + stdev, '--', color='orange', linewidth=0.3, label='Stdev')
    ax1.plot(steps, smoothed_loss - stdev, '--', color='orange', linewidth=0.3)

    # Plot validation loss curve (use the shared constant, not a literal).
    val_steps, val_loss = get_plot(jlog['val_loss'])
    ax1.plot(val_steps, val_loss, color=VAL_LOSS_COLOR, label='Validation loss')

    min_val_loss_step = val_steps[np.argmin(val_loss)]
    ax1.axvline(min_val_loss_step, linestyle='dashed', color=VAL_LOSS_COLOR, linewidth=0.5, label='Validation loss minimum')

    # Plot BLEU curves on a secondary y-axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel('BLEU')
    val_steps, val_bleu = get_plot(jlog['val_bleu'])
    ax2.plot(val_steps, val_bleu, color=VAL_BLEU_COLOR, label='Validation BLEU')
    mvb_step, _ = highlight_max_point((val_steps, val_bleu), color=VAL_BLEU_COLOR)

    # Values to be labeled on the plot.
    max_val_bleu_step = val_steps[np.argmax(val_bleu)]
    min_loss_bleu = val_bleu[val_steps.index(min_val_loss_step)]

    if 'test_bleu' in jlog:
        test_steps, test_bleu = get_plot(jlog['test_bleu'])
        # BUG FIX: the test curve was plotted against val_steps; use its
        # own step axis.
        ax2.plot(test_steps, test_bleu, color=TEST_BLEU_COLOR, label='Test BLEU')
        highlight_max_point((test_steps, test_bleu), color=TEST_BLEU_COLOR)
    ax2.tick_params(axis='y')

    # Annotate the validation BLEU at the validation-loss minimum.
    ax2.plot(min_val_loss_step, min_loss_bleu, 'o-', color=VAL_BLEU_COLOR)
    ax2.annotate("{:.2f}".format(min_loss_bleu), (min_val_loss_step, min_loss_bleu))

    if 'test_bleu' in jlog:
        # BUG FIX: index the test curve through test_steps, not val_steps
        # (the two evaluation schedules are not guaranteed identical).
        # Test BLEU at the step of minimal validation loss:
        min_loss_test_bleu = test_bleu[test_steps.index(min_val_loss_step)]
        ax2.plot(min_val_loss_step, min_loss_test_bleu, 'o-', color=TEST_BLEU_COLOR)
        ax2.annotate("{:.2f}".format(min_loss_test_bleu), (min_val_loss_step, min_loss_test_bleu))

        # Test BLEU at the step of maximal validation BLEU:
        max_val_bleu_test = test_bleu[test_steps.index(max_val_bleu_step)]
        ax2.plot(mvb_step, max_val_bleu_test, 'o-', color=TEST_BLEU_COLOR)
        ax2.annotate("{:.2f}".format(max_val_bleu_test), (max_val_bleu_step, max_val_bleu_test))

    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax2.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid()
    plt.savefig(args.output)

    # Produce json with training summary
    if args.dump_json:
        summary = OrderedDict()
        summary['args'] = OrderedDict(jlog['parameters'])
        summary['min_val_loss'] = min(val_loss)
        summary['max_val_bleu'] = max(val_bleu)
        if 'test_bleu' in jlog:
            # BUG FIX: test_bleu is only defined when test entries exist;
            # the original raised NameError on logs without a test curve.
            summary['max_test_bleu'] = max(test_bleu)
        summary['final_values'] = jlog['training_summary']
        summary['avg_epoch_loss'] = [x.mean() for x in np.array_split(np.array(loss), jlog['parameters']['max_epoch'])]
        summary['min_val_loss_step'] = min_val_loss_step
        # BUG FIX: close the output file deterministically instead of
        # leaking the handle from json.dump(summary, open(...)).
        with open(args.dump_json, 'w') as out_f:
            json.dump(summary, out_f)

if __name__ == '__main__':
    # Command-line entry point: parse flags and hand off to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--title', type=str)
    cli.add_argument('--log-file', type=str)
    cli.add_argument('--output', '-o', type=str)
    cli.add_argument('--dump-json', '-j', type=str)
    main(cli.parse_args())