Unverified Commit d7302cc3 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

deactivate autograd during validation (#560)

parent d3847898
......@@ -294,13 +294,16 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0
count = 0
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
true_energies = properties['energies'].to(device).float()
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(False)
with torch.no_grad():
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
true_energies = properties['energies'].to(device).float()
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count))
......
......@@ -208,13 +208,16 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0
count = 0
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
true_energies = properties['energies'].to(device).float()
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(False)
with torch.no_grad():
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
true_energies = properties['energies'].to(device).float()
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment