Unverified Commit 4dcd6ab0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use a separate optimizer to pretrain, and pretrain more (#226)

parent 2ad4126f
......@@ -234,15 +234,44 @@ optimizer = torchani.optim.AdamW([
# File to which training progress is periodically checkpointed.
latest_checkpoint = 'latest.pt'
# If a checkpoint from a previous run exists, skip the pretraining phase.
pretrained = os.path.isfile(latest_checkpoint)
###############################################################################
During training, we need to validate on the validation set and if validation error
# is better than the best, then save the new best model to a checkpoint
# helper function to convert energy unit from Hartree to kcal/mol
def hartree2kcal(x):
    """Convert an energy from Hartree to kcal/mol."""
    # 1 Hartree = 627.509 kcal/mol
    return x * 627.509
def validate():
    """Compute the model's RMSE on the validation set, in kcal/mol.

    Iterates over the ``validation`` loader, accumulating the sum of squared
    errors and the molecule count, then returns the root-mean-square error
    converted from Hartree to kcal/mol via ``hartree2kcal``.
    """
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    # Inference only: disable autograd so validation does not build
    # computation graphs or retain activations for every batch.
    with torch.no_grad():
        for batch_x, batch_y in validation:
            true_energies = batch_y['energies']
            predicted_energies = []
            # Each batch is split into chunks (species, coordinates);
            # run the model per chunk and concatenate the energies.
            for chunk_species, chunk_coordinates in batch_x:
                _, chunk_energies = model((chunk_species, chunk_coordinates))
                predicted_energies.append(chunk_energies)
            predicted_energies = torch.cat(predicted_energies)
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    return hartree2kcal(math.sqrt(total_mse / count))
###############################################################################
# If the model is not pretrained yet, we need to run the pretrain.
# Maximum number of passes over the training set during pretraining.
pretrain_epoches = 10
# Stop pretraining early once validation RMSE drops below this threshold.
pretrain_criterion = 10 # kcal/mol
# Per-sample loss (reduction='none') so each molecule's squared error can
# be normalized by its atom count before averaging.
mse = torch.nn.MSELoss(reduction='none')
if not pretrained:
print("pre-training...")
epoch = 0
for _ in range(pretrain_epoches):
rmse = math.inf
pretrain_optimizer = torch.optim.Adam(nn.parameters())
while rmse > pretrain_criterion:
for batch_x, batch_y in tqdm.tqdm(training):
true_energies = batch_y['energies']
predicted_energies = []
......@@ -254,9 +283,11 @@ if not pretrained:
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
optimizer.zero_grad()
pretrain_optimizer.zero_grad()
loss.backward()
optimizer.step()
rmse = validate()
print('RMSE:', rmse, 'Target RMSE:', pretrain_criterion)
torch.save({
'nn': nn.state_dict(),
'optimizer': optimizer.state_dict(),
......@@ -278,32 +309,6 @@ optimizer.load_state_dict(checkpoint['optimizer'])
# Restore the LR scheduler state only when present — checkpoints written
# before the scheduler was introduced lack this key.
if 'scheduler' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler'])
###############################################################################
# During training, we need to validate on validation set and if validation error
# is better than the best, then save the new best model to a checkpoint
# helper function to convert energy unit from Hartree to kcal/mol
def hartree2kcal(x):
    """Return *x* (Hartree) expressed in kcal/mol (factor 627.509)."""
    conversion_factor = 627.509
    return conversion_factor * x
def validate():
    """Return the validation-set RMSE of the model, in kcal/mol."""
    sum_sq_loss = torch.nn.MSELoss(reduction='sum')
    squared_error_total = 0.0
    n_molecules = 0
    for features, labels in validation:
        target_energies = labels['energies']
        # Evaluate the model chunk-by-chunk and stitch the energies together.
        chunk_outputs = [
            model((species, coordinates))[1]
            for species, coordinates in features
        ]
        prediction = torch.cat(chunk_outputs)
        squared_error_total += sum_sq_loss(prediction, target_energies).item()
        n_molecules += prediction.shape[0]
    return hartree2kcal(math.sqrt(squared_error_total / n_molecules))
###############################################################################
# Finally, we come to the training loop.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment