Unverified Commit d7302cc3 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

deactivate autograd during validation (#560)

parent d3847898
...@@ -294,13 +294,16 @@ def validate(): ...@@ -294,13 +294,16 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum') mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0 total_mse = 0.0
count = 0 count = 0
for properties in validation: model.train(False)
species = properties['species'].to(device) with torch.no_grad():
coordinates = properties['coordinates'].to(device).float() for properties in validation:
true_energies = properties['energies'].to(device).float() species = properties['species'].to(device)
_, predicted_energies = model((species, coordinates)) coordinates = properties['coordinates'].to(device).float()
total_mse += mse_sum(predicted_energies, true_energies).item() true_energies = properties['energies'].to(device).float()
count += predicted_energies.shape[0] _, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
......
...@@ -208,13 +208,16 @@ def validate(): ...@@ -208,13 +208,16 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum') mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0 total_mse = 0.0
count = 0 count = 0
for properties in validation: model.train(False)
species = properties['species'].to(device) with torch.no_grad():
coordinates = properties['coordinates'].to(device).float() for properties in validation:
true_energies = properties['energies'].to(device).float() species = properties['species'].to(device)
_, predicted_energies = model((species, coordinates)) coordinates = properties['coordinates'].to(device).float()
total_mse += mse_sum(predicted_energies, true_energies).item() true_energies = properties['energies'].to(device).float()
count += predicted_energies.shape[0] _, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment