Unverified Commit 39137175 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Misc improvements (#117)

parent e4fe2a5c
...@@ -16,7 +16,7 @@ class TestEnsemble(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestEnsemble(unittest.TestCase):
def _test_molecule(self, coordinates, species): def _test_molecule(self, coordinates, species):
builtins = torchani.neurochem.Builtins() builtins = torchani.neurochem.Builtins()
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates.requires_grad_(True)
aev = builtins.aev_computer aev = builtins.aev_computer
ensemble = builtins.models ensemble = builtins.models
models = [torch.nn.Sequential(aev, m) for m in ensemble] models = [torch.nn.Sequential(aev, m) for m in ensemble]
......
...@@ -22,7 +22,7 @@ class TestForce(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestForce(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i)) datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f) coordinates, species, _, _, _, forces = pickle.load(f)
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates.requires_grad_(True)
_, energies = self.model((species, coordinates)) _, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), derivative = torch.autograd.grad(energies.sum(),
coordinates)[0] coordinates)[0]
...@@ -36,7 +36,7 @@ class TestForce(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestForce(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i)) datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f) coordinates, species, _, _, _, forces = pickle.load(f)
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates.requires_grad_(True)
species_coordinates.append((species, coordinates)) species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces)) coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.utils.pad_coordinates( species, coordinates = torchani.utils.pad_coordinates(
......
...@@ -155,11 +155,7 @@ class AEVComputer(torch.nn.Module): ...@@ -155,11 +155,7 @@ class AEVComputer(torch.nn.Module):
"""Shape (conformations, atoms, atoms) storing Rij distances""" """Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1) padding_mask = (species == -1).unsqueeze(1)
distances = torch.where( distances = distances.masked_fill(padding_mask, math.inf)
padding_mask,
torch.tensor(math.inf, dtype=self.EtaR.dtype,
device=self.EtaR.device),
distances)
distances, indices = distances.sort(-1) distances, indices = distances.sort(-1)
...@@ -172,11 +168,10 @@ class AEVComputer(torch.nn.Module): ...@@ -172,11 +168,10 @@ class AEVComputer(torch.nn.Module):
radial_terms = self._radial_subaev_terms(distances) radial_terms = self._radial_subaev_terms(distances)
indices_a = indices.index_select(-1, inRca) indices_a = indices.index_select(-1, inRca)
new_shape = list(indices_a.shape) + [3]
# TODO: remove this workaround when gather support broadcasting # TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532 # https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(*new_shape) _indices_a = indices_a.unsqueeze(-1).expand(-1, -1, -1, 3)
vec = vec.gather(-2, _indices_a) vec = vec.gather(-2, _indices_a)
vec = self._combinations(vec, -2) vec = self._combinations(vec, -2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment