Unverified Commit 18e63f1c authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

Update utils.py (#338)

parent 053f0863
...@@ -215,8 +215,8 @@ class EnergyShifter(torch.nn.Module): ...@@ -215,8 +215,8 @@ class EnergyShifter(torch.nn.Module):
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
species, energies = species_energies species, energies = species_energies
sae = self.sae(species).to(energies.dtype).to(energies.device) sae = self.sae(species).to(energies.device)
return species, energies + sae return species, energies.to(sae.dtype) + sae
class ChemicalSymbolsToInts: class ChemicalSymbolsToInts:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment