Unverified Commit d7302cc3 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

deactivate autograd during validation (#560)

parent d3847898
...@@ -294,6 +294,8 @@ def validate(): ...@@ -294,6 +294,8 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum') mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0 total_mse = 0.0
count = 0 count = 0
model.train(False)
with torch.no_grad():
for properties in validation: for properties in validation:
species = properties['species'].to(device) species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float() coordinates = properties['coordinates'].to(device).float()
...@@ -301,6 +303,7 @@ def validate(): ...@@ -301,6 +303,7 @@ def validate():
_, predicted_energies = model((species, coordinates)) _, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item() total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0] count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
......
...@@ -208,6 +208,8 @@ def validate(): ...@@ -208,6 +208,8 @@ def validate():
mse_sum = torch.nn.MSELoss(reduction='sum') mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0 total_mse = 0.0
count = 0 count = 0
model.train(False)
with torch.no_grad():
for properties in validation: for properties in validation:
species = properties['species'].to(device) species = properties['species'].to(device)
coordinates = properties['coordinates'].to(device).float() coordinates = properties['coordinates'].to(device).float()
...@@ -215,6 +217,7 @@ def validate(): ...@@ -215,6 +217,7 @@ def validate():
_, predicted_energies = model((species, coordinates)) _, predicted_energies = model((species, coordinates))
total_mse += mse_sum(predicted_energies, true_energies).item() total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0] count += predicted_energies.shape[0]
model.train(True)
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment