"tests/kernels/cache.py" did not exist on "cbf8779afafdaba2ddc6e2212d67c40f1b6e11fd"
Unverified Commit d7302cc3 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

deactivate autograd during validation (#560)

parent d3847898
......@@ -294,6 +294,8 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0
count = 0
model.train(False)
with torch.no_grad():
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
......@@ -301,6 +303,7 @@ def validate():
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count))
......
......@@ -208,6 +208,8 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0
count = 0
model.train(False)
with torch.no_grad():
for properties in validation:
species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float()
......@@ -215,6 +217,7 @@ def validate():
_, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment