Unverified Commit 338f896a authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

massive training time improvement (#431)

* massive training time improvement

* flake8

* improve validation also

* improve training-benchmark.py

* fix devices

* send num_atoms to GPU

* rerun
parent a6d7fba3
......@@ -86,7 +86,6 @@ batch_size = 2560
training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, rm_outlier=True, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
print('Self atomic energies: ', energy_shifter.self_energies)
###############################################################################
......@@ -283,10 +282,13 @@ def validate():
for batch_x, batch_y in validation:
true_energies = batch_y['energies']
predicted_energies = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x:
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
return hartree2kcalmol(math.sqrt(total_mse / count))
......@@ -338,14 +340,17 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x:
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
AdamW.zero_grad()
......
......@@ -163,16 +163,19 @@ if __name__ == "__main__":
true_energies = batch_y['energies'].to(parser.device)
predicted_energies = []
num_atoms = []
atomic_properties = []
for chunk_species, chunk_coordinates in batch_x:
chunk_species = chunk_species.to(parser.device)
chunk_coordinates = chunk_coordinates.to(parser.device)
chunk_coordiantes = chunk_coordinates.to(parser.device)
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies.to(true_energies.dtype)
num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment