Unverified Commit 5bb66915 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

fix sort_by_species (#57)

parent e5439f3d
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -21,6 +21,12 @@ class TestEnergies(unittest.TestCase):
nnp, shift_energy)
def _test_molecule(self, coordinates, species, energies):
# generate a random permute
atoms = len(species)
randperm = torch.randperm(atoms)
coordinates = coordinates.index_select(1, randperm)
species = [species[i] for i in randperm.tolist()]
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_.squeeze()).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -18,6 +18,13 @@ class TestForce(unittest.TestCase):
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, forces):
# generate a random permute
atoms = len(species)
randperm = torch.randperm(atoms)
coordinates = coordinates.index_select(1, randperm)
forces = forces.index_select(1, randperm)
species = [species[i] for i in randperm.tolist()]
coordinates = torch.tensor(coordinates, requires_grad=True)
_, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment