Unverified commit 338f896a, authored by Ignacio Pickering and committed by GitHub
Browse files

massive training time improvement (#431)

* massive training time improvement

* flake8

* improve validation also

* improve training-benchmark.py

* fix devices

* send num_atoms to GPU

* rerun
parent a6d7fba3
...@@ -86,7 +86,6 @@ batch_size = 2560 ...@@ -86,7 +86,6 @@ batch_size = 2560
training, validation = torchani.data.load_ani_dataset( training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, rm_outlier=True, device=device, dspath, species_to_tensor, batch_size, rm_outlier=True, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None]) transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
print('Self atomic energies: ', energy_shifter.self_energies) print('Self atomic energies: ', energy_shifter.self_energies)
############################################################################### ###############################################################################
...@@ -283,10 +282,13 @@ def validate(): ...@@ -283,10 +282,13 @@ def validate():
for batch_x, batch_y in validation: for batch_x, batch_y in validation:
true_energies = batch_y['energies'] true_energies = batch_y['energies']
predicted_energies = [] predicted_energies = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
chunk_energies = model((chunk_species, chunk_coordinates)).energies atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
predicted_energies.append(chunk_energies) atomic_properties.append(atomic_chunk)
predicted_energies = torch.cat(predicted_energies)
atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
total_mse += mse_sum(predicted_energies, true_energies).item() total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0] count += predicted_energies.shape[0]
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
...@@ -338,14 +340,17 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs): ...@@ -338,14 +340,17 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
true_energies = batch_y['energies'] true_energies = batch_y['energies']
predicted_energies = [] predicted_energies = []
num_atoms = [] num_atoms = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1)) num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies) atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
num_atoms = torch.cat(num_atoms) num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean() loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
AdamW.zero_grad() AdamW.zero_grad()
......
...@@ -163,16 +163,19 @@ if __name__ == "__main__": ...@@ -163,16 +163,19 @@ if __name__ == "__main__":
true_energies = batch_y['energies'].to(parser.device) true_energies = batch_y['energies'].to(parser.device)
predicted_energies = [] predicted_energies = []
num_atoms = [] num_atoms = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
chunk_species = chunk_species.to(parser.device) chunk_species = chunk_species.to(parser.device)
chunk_coordinates = chunk_coordinates.to(parser.device) chunk_coordinates = chunk_coordinates.to(parser.device)
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1)) num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies) atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies.to(true_energies.dtype)
num_atoms = torch.cat(num_atoms) num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean() loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy() rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
loss.backward() loss.backward()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment