Unverified Commit 9036b443 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

Fix CUDA support in ANIModel and EnergyShifter (#341)

* fix device type

* Update utils.py

* Update utils.py

fix typo!
parent 6ee36733
...@@ -40,7 +40,7 @@ class ANIModel(torch.nn.Module): ...@@ -40,7 +40,7 @@ class ANIModel(torch.nn.Module):
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
output = torch.full(species_.shape, self.padding_fill, output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype) dtype=aev.dtype, device=species.device)
i = 0 i = 0
for m in self.module_list: for m in self.module_list:
mask = (species_ == i) mask = (species_ == i)
......
...@@ -192,7 +192,7 @@ class EnergyShifter(torch.nn.Module): ...@@ -192,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
intercept = self.self_energies[-1] intercept = self.self_energies[-1]
self_energies = self.self_energies[species] self_energies = self.self_energies[species]
self_energies[species == torch.tensor(-1)] = torch.tensor(0) self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device)
return self_energies.sum(dim=1) + intercept return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties): def subtract_from_dataset(self, atomic_properties, properties):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment