Commit 41aa0f46 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by Gao, Xiang
Browse files

Fix dtype in self_energies (#347)

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* remove float32 dtypes from comp6.py
parent e7af3cbd
...@@ -8,7 +8,6 @@ import tqdm ...@@ -8,7 +8,6 @@ import tqdm
HARTREE2KCAL = 627.509 HARTREE2KCAL = 627.509
dtype = torch.float32
# parse command line arguments # parse command line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -21,7 +20,7 @@ parser.add_argument('-d', '--device', ...@@ -21,7 +20,7 @@ parser.add_argument('-d', '--device',
parser = parser.parse_args() parser = parser.parse_args()
# run benchmark # run benchmark
ani1x = torchani.models.ANI1x().to(dtype).to(parser.device) ani1x = torchani.models.ANI1x().to(parser.device)
def recursive_h5_files(base): def recursive_h5_files(base):
...@@ -80,14 +79,11 @@ def do_benchmark(model): ...@@ -80,14 +79,11 @@ def do_benchmark(model):
rmse_averager_force = Averager() rmse_averager_force = Averager()
for i in tqdm.tqdm(dataset, position=0, desc="dataset"): for i in tqdm.tqdm(dataset, position=0, desc="dataset"):
# read # read
coordinates = torch.tensor( coordinates = torch.tensor(i['coordinates'], device=parser.device)
i['coordinates'], dtype=dtype, device=parser.device)
species = model.species_to_tensor(i['species']) \ species = model.species_to_tensor(i['species']) \
.unsqueeze(0).expand(coordinates.shape[0], -1) .unsqueeze(0).expand(coordinates.shape[0], -1)
energies = torch.tensor(i['energies'], dtype=dtype, energies = torch.tensor(i['energies'], device=parser.device)
device=parser.device) forces = torch.tensor(i['forces'], device=parser.device)
forces = torch.tensor(i['forces'], dtype=dtype,
device=parser.device)
# compute # compute
energies2, forces2 = by_batch(species, coordinates, model) energies2, forces2 = by_batch(species, coordinates, model)
ediff = energies - energies2 ediff = energies - energies2
......
...@@ -192,7 +192,7 @@ class EnergyShifter(torch.nn.Module): ...@@ -192,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
intercept = self.self_energies[-1] intercept = self.self_energies[-1]
self_energies = self.self_energies[species] self_energies = self.self_energies[species]
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device) self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double)
return self_energies.sum(dim=1) + intercept return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties): def subtract_from_dataset(self, atomic_properties, properties):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment