Unverified Commit e67a2ad5 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Cherry-pick roitberg-group#22 (#546)



* commit farhads and jinzes changes to fix some bugs [WIP] (#22)

* commit farhads and jinzes changes to fix some bugs

* Add tests for correct inputs

* add missing

* fix
Co-authored-by: default avatarIgnacio Pickering <ign.pickering@gmail.com>
parent b270d59d
......@@ -9,6 +9,29 @@ path = os.path.dirname(os.path.realpath(__file__))
N = 97
class TestCorrectInput(torchani.testing.TestCase):
def setUp(self):
self.model = torchani.models.ANI1x(model_index=0, periodic_table_index=False)
self.converter = torchani.nn.SpeciesConverter(['H', 'C', 'N', 'O'])
self.aev_computer = self.model.aev_computer
self.ani_model = self.model.neural_networks
def testUnknownSpecies(self):
# unsupported atomic number raises a value error
self.assertRaises(ValueError, self.converter, (torch.tensor([[1, 1, 7, 10]]), torch.zeros((1, 4, 3))))
# larger index than supported by the model raises a value error
self.assertRaises(ValueError, self.model, (torch.tensor([[0, 1, 2, 4]]), torch.zeros((1, 4, 3))))
def testIncorrectShape(self):
# non matching shapes between species and coordinates
self.assertRaises(AssertionError, self.model, (torch.tensor([[0, 1, 2, 3]]), torch.zeros((1, 3, 3))))
self.assertRaises(AssertionError, self.aev_computer, (torch.tensor([[0, 1, 2, 3]]), torch.zeros((1, 3, 3))))
self.assertRaises(AssertionError, self.ani_model, (torch.tensor([[0, 1, 2, 3]]), torch.zeros((1, 3, 384))))
self.assertRaises(AssertionError, self.model, (torch.tensor([[0, 1, 2, 3]]), torch.zeros((1, 4, 4))))
self.assertRaises(AssertionError, self.model, (torch.tensor([0, 1, 2, 3]), torch.zeros((4, 3))))
class TestEnergies(torchani.testing.TestCase):
# tests the predicions for a torchani.nn.Sequential(AEVComputer(),
# ANIModel(), EnergyShifter()) against precomputed values
......
......@@ -470,7 +470,9 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length())``
"""
species, coordinates = input_
assert species.dim() == 2
assert species.shape == coordinates.shape[:-1]
assert coordinates.shape[-1] == 3
if cell is None and pbc is None:
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
......
......@@ -55,6 +55,8 @@ class ANIModel(torch.nn.ModuleDict):
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
species, aev = species_aev
assert species.shape == aev.shape[:-1]
atomic_energies = self._atomic_energies((species, aev))
# shape of atomic energies is (C, A)
return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment