Commit 48eeb3dc authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Scale by sqrt(natoms) instead of natoms (#259)

parent 74ea32b9
...@@ -331,7 +331,7 @@ for _ in range(scheduler.last_epoch + 1, max_epochs): ...@@ -331,7 +331,7 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
predicted_energies.append(chunk_energies) predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype) num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies) predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms).mean() loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment