Unverified commit d47d2579, authored by Gao, Xiang, committed by GitHub

LR scheduler for SGD (#282)

parent a2ba46e9
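
This commit attaches a second `ReduceLROnPlateau` scheduler to the SGD optimizer (which trains the biases), configured identically to the one already driving AdamW (which trains the weights), and renames the existing `scheduler` to `AdamW_scheduler`. Both schedulers are stepped with the same validation RMSE, so the two learning rates decay together.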
```diff
@@ -597,20 +597,25 @@ if sys.version_info[0] > 2:
             AdamW_optim = AdamW(self.weights, lr=self.init_lr)
             SGD_optim = torch.optim.SGD(self.biases, lr=self.init_lr)
-            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                 AdamW_optim,
                 factor=0.5,
                 patience=100,
                 threshold=0)
+            SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                SGD_optim,
+                factor=0.5,
+                patience=100,
+                threshold=0)
 
             while True:
                 rmse = self.evaluate(self.validation_set)
                 learning_rate = AdamW_optim.param_groups[0]['lr']
-                if learning_rate < self.min_lr or scheduler.last_epoch > self.nmax:
+                if learning_rate < self.min_lr or AdamW_scheduler.last_epoch > self.nmax:
                     break
                 # checkpoint
-                if scheduler.is_better(rmse, scheduler.best):
+                if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
                     no_improve_count = 0
                     torch.save(self.nn.state_dict(), self.model_checkpoint)
                 else:
```
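For context, the hunk above pairs each optimizer with its own `ReduceLROnPlateau` instance watching the same validation metric. A minimal, self-contained sketch of that setup (the model, layer shape, and variable names are illustrative, and `torch.optim.AdamW` stands in for the `AdamW` the repository imports):

```python
import torch

# Stand-in network: the real Trainer splits its parameters into weights
# (trained with AdamW) and biases (trained with plain SGD).
model = torch.nn.Linear(8, 1)

adamw = torch.optim.AdamW([model.weight], lr=1e-3)
sgd = torch.optim.SGD([model.bias], lr=1e-3)

# One scheduler per optimizer, configured identically, so both learning
# rates are halved together whenever validation RMSE stops improving.
adamw_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    adamw, factor=0.5, patience=100, threshold=0)
sgd_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    sgd, factor=0.5, patience=100, threshold=0)
```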
```diff
@@ -619,17 +624,19 @@ if sys.version_info[0] > 2:
                 if no_improve_count > self.max_nonimprove:
                     break
-                scheduler.step(rmse)
+                AdamW_scheduler.step(rmse)
+                SGD_scheduler.step(rmse)
 
                 if self.tensorboard is not None:
-                    self.tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
-                    self.tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
-                    self.tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
-                    self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, scheduler.last_epoch)
+                    self.tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
+                    self.tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
+                    self.tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
+                    self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, AdamW_scheduler.last_epoch)
 
                 for i, (batch_x, batch_y) in self.tqdm(
                         enumerate(self.training_set),
                         total=len(self.training_set),
-                        desc='epoch {}'.format(scheduler.last_epoch)
+                        desc='epoch {}'.format(AdamW_scheduler.last_epoch)
                 ):
                     true_energies = batch_y['energies']
```
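The hunk above keys every TensorBoard scalar and the tqdm description to `AdamW_scheduler.last_epoch`. On `ReduceLROnPlateau`, that attribute is incremented once per `step()` call (an implementation detail the Trainer relies on), so it doubles as the epoch index. A quick sketch, with made-up RMSE values:

```python
import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, factor=0.5, patience=100, threshold=0)

for rmse in [1.0, 0.9, 0.9]:  # one value per "epoch"
    sched.step(rmse)
print(sched.last_epoch)       # 3 -- one increment per step()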
```diff
@@ -650,12 +657,12 @@ if sys.version_info[0] > 2:
                     # write current batch loss to TensorBoard
                     if self.tensorboard is not None:
-                        self.tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(self.training_set) + i)
+                        self.tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(self.training_set) + i)
 
                 # log elapsed time
                 elapsed = round(timeit.default_timer() - start, 2)
                 if self.tensorboard is not None:
-                    self.tensorboard.add_scalar('time_vs_epoch', elapsed, scheduler.last_epoch)
+                    self.tensorboard.add_scalar('time_vs_epoch', elapsed, AdamW_scheduler.last_epoch)
 
 __all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble', 'Trainer']
```
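
Putting the pieces together, the epoch loop after this commit steps both schedulers with the same validation RMSE and checkpoints via the AdamW scheduler's `is_better`/`best` bookkeeping. A self-contained sketch of that loop shape (the RMSE sequence, file name, and setup are made up, not repository code):

```python
import torch

model = torch.nn.Linear(8, 1)
adamw = torch.optim.AdamW([model.weight], lr=1e-3)
sgd = torch.optim.SGD([model.bias], lr=1e-3)
adamw_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    adamw, factor=0.5, patience=100, threshold=0)
sgd_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    sgd, factor=0.5, patience=100, threshold=0)

for epoch in range(5):
    rmse = 1.0 / (epoch + 1)  # stand-in for self.evaluate(self.validation_set)
    # checkpoint when the metric improves, mirroring the diff's is_better check
    if adamw_scheduler.is_better(rmse, adamw_scheduler.best):
        torch.save(model.state_dict(), 'best.pt')
    adamw_scheduler.step(rmse)  # may halve AdamW's learning rate on plateau
    sgd_scheduler.step(rmse)    # the line this commit adds: keeps SGD in step
```

Driving both schedulers from the one metric keeps the weights' and biases' learning rates from drifting apart, which is presumably why the commit duplicates the AdamW scheduler's hyperparameters rather than tuning SGD's separately.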