import datetime
import os
import warnings

import matplotlib
matplotlib.use('Agg')  # must be selected before pyplot is imported
from matplotlib import pyplot as plt
import numpy as np
import scipy.signal
from tensorflow import keras
from tensorflow.keras import backend as K


class LossHistory(keras.callbacks.Callback):
    """Record per-epoch train/validation loss to text files and a PNG plot.

    A timestamped directory ``<log_dir>/loss_<timestamp>`` is created on
    construction.  After every epoch the current ``loss`` and ``val_loss``
    values are appended to ``epoch_loss_<timestamp>.txt`` and
    ``epoch_val_loss_<timestamp>.txt`` and the loss curve image
    ``epoch_loss_<timestamp>.png`` is regenerated.

    Args:
        log_dir: Parent directory under which the timestamped output
            directory is created.
    """

    def __init__(self, log_dir):
        super(LossHistory, self).__init__()
        curr_time = datetime.datetime.now()
        time_str = datetime.datetime.strftime(curr_time, '%Y_%m_%d_%H_%M_%S')
        self.log_dir = log_dir
        self.time_str = time_str
        self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
        self.losses = []
        self.val_loss = []

        # An already existing directory is fine; any other failure (e.g.
        # permissions) is raised here instead of at the first epoch's open().
        os.makedirs(self.save_path, exist_ok=True)

    def _append_value(self, prefix, value):
        """Append ``str(value)`` as one line to ``<prefix><timestamp>.txt``."""
        file_name = prefix + str(self.time_str) + ".txt"
        with open(os.path.join(self.save_path, file_name), 'a') as f:
            f.write(str(value))
            f.write("\n")

    def on_epoch_end(self, batch, logs=None):
        """Store the epoch's losses, append them to disk and redraw the plot.

        Args:
            batch: Unused.  The name is kept for backward compatibility;
                Keras passes the epoch index in this position.
            logs: Metrics dict for the epoch.  Missing ``loss`` /
                ``val_loss`` entries are recorded as ``None``.
        """
        logs = logs or {}
        loss = logs.get('loss')
        val_loss = logs.get('val_loss')

        self.losses.append(loss)
        self.val_loss.append(val_loss)
        self._append_value("epoch_loss_", loss)
        self._append_value("epoch_val_loss_", val_loss)
        self.loss_plot()

    def loss_plot(self):
        """Render the raw (and, when possible, smoothed) loss curves to PNG."""
        iters = range(len(self.losses))
        # Savitzky-Golay window: short while the history is short.
        num = 5 if len(self.losses) < 25 else 15

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

            # The filter needs at least `num` samples; with fewer epochs the
            # smoothed curves are simply omitted.
            if len(self.losses) >= num:
                try:
                    plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3),
                             'green', linestyle='--', linewidth=2,
                             label='smooth train loss')
                    plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3),
                             '#8B4513', linestyle='--', linewidth=2,
                             label='smooth val loss')
                except (ValueError, TypeError) as err:
                    # Smoothing is best-effort: keep the raw curves and
                    # report why the smoothed ones are missing.
                    warnings.warn('Could not smooth loss curves: %s' % (err,),
                                  RuntimeWarning)

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('A Loss Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.save_path,
                                     "epoch_loss_" + str(self.time_str) + ".png"))
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so figures do not accumulate across epochs.
            plt.cla()
            plt.close("all")


class ExponentDecayScheduler(keras.callbacks.Callback):
    """Multiply the optimizer learning rate by ``decay_rate`` after each epoch.

    Args:
        decay_rate: Factor applied to the current learning rate at the end
            of every epoch.
        verbose: When greater than 0, print the new learning rate.

    Attributes:
        learning_rates: Learning rates set by this callback, one per epoch.
    """

    def __init__(self, decay_rate, verbose=0):
        super(ExponentDecayScheduler, self).__init__()
        self.decay_rate = decay_rate
        self.verbose = verbose
        self.learning_rates = []

    def on_epoch_end(self, batch, logs=None):
        """Apply one decay step to ``self.model.optimizer.lr``.

        Args:
            batch: Unused.  The name is kept for backward compatibility;
                Keras passes the epoch index in this position.
            logs: Unused.
        """
        learning_rate = K.get_value(self.model.optimizer.lr) * self.decay_rate
        K.set_value(self.model.optimizer.lr, learning_rate)
        self.learning_rates.append(learning_rate)
        if self.verbose > 0:
            print('Setting learning rate to %s.' % (learning_rate))


class ModelCheckpoint(keras.callbacks.Callback):
    """Save the model (or only its weights) every ``period`` epochs.

    Args:
        filepath: Destination path.  It is formatted with ``epoch`` (1-based)
            and every key of the epoch's ``logs`` dict via ``str.format``.
        monitor: Name of the ``logs`` entry used to decide whether the model
            improved (only consulted when ``save_best_only`` is true).
        verbose: When greater than 0, print a message for each decision.
        save_best_only: Save only when the monitored value improves on the
            best value seen so far.
        save_weights_only: Call ``model.save_weights`` instead of
            ``model.save``.
        mode: ``'min'``, ``'max'`` or ``'auto'``.  In ``'auto'`` mode the
            direction is ``max`` when ``monitor`` contains ``'acc'`` or starts
            with ``'fmeasure'`` and ``min`` otherwise.  Unknown values fall
            back to ``'auto'`` with a ``RuntimeWarning``.
        period: Number of epochs between checkpoint attempts.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        # np.inf (lower case) is used because the np.Inf alias was removed
        # in NumPy 2.0.
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def _save(self, filepath):
        """Write the full model or only its weights to ``filepath``."""
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint the model if ``period`` epochs have elapsed.

        Args:
            epoch: Zero-based epoch index; ``epoch + 1`` is used in the file
                name and in messages.
            logs: Metrics dict for the epoch, used both to format
                ``filepath`` and to look up the monitored value.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return

        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save(filepath)
        else:
            if self.verbose > 0:
                print('\nEpoch %05d: %s did not improve' %
                      (epoch + 1, self.monitor))