Unverified Commit b7cab4f1 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Support batching non isomers in models and aev computer (#62)

parent b63f6e40
...@@ -8,7 +8,7 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5. ...@@ -8,7 +8,7 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat') # noqa: E501 sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat') # noqa: E501
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501 network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
aev_computer = torchani.SortedAEV(const_file=const_file) aev_computer = torchani.AEVComputer(const_file=const_file)
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir, nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8) ensemble=8)
......
...@@ -18,10 +18,9 @@ def atomic(): ...@@ -18,10 +18,9 @@ def atomic():
def get_or_create_model(filename, benchmark=False, def get_or_create_model(filename, benchmark=False,
device=torch.device('cpu')): device=torch.device('cpu')):
aev_computer = torchani.SortedAEV(benchmark=benchmark) aev_computer = torchani.AEVComputer(benchmark=benchmark)
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.CustomModel( model = torchani.models.CustomModel(
reducer=torch.sum,
benchmark=benchmark, benchmark=benchmark,
per_species={ per_species={
'C': atomic(), 'C': atomic(),
......
...@@ -36,7 +36,7 @@ parser = parser.parse_args() ...@@ -36,7 +36,7 @@ parser = parser.parse_args()
# load modules and datasets # load modules and datasets
device = torch.device(parser.device) device = torch.device(parser.device)
aev_computer = torchani.SortedAEV(const_file=parser.const_file) aev_computer = torchani.AEVComputer(const_file=parser.const_file)
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species, nn = torchani.models.NeuroChemNNP(aev_computer.species,
from_=parser.network_dir, from_=parser.network_dir,
......
...@@ -11,7 +11,7 @@ N = 97 ...@@ -11,7 +11,7 @@ N = 97
class TestAEV(unittest.TestCase): class TestAEV(unittest.TestCase):
def setUp(self): def setUp(self):
self.aev_computer = torchani.SortedAEV() self.aev_computer = torchani.AEVComputer()
self.radial_length = self.aev_computer.radial_length self.radial_length = self.aev_computer.radial_length
self.prepare = torchani.PrepareInput(self.aev_computer.species) self.prepare = torchani.PrepareInput(self.aev_computer.species)
self.aev = torch.nn.Sequential( self.aev = torch.nn.Sequential(
...@@ -20,20 +20,9 @@ class TestAEV(unittest.TestCase): ...@@ -20,20 +20,9 @@ class TestAEV(unittest.TestCase):
) )
self.tolerance = 1e-5 self.tolerance = 1e-5
def _test_molecule(self, coordinates, species, expected_radial, def _assertAEVEqual(self, expected_radial, expected_angular, aev):
expected_angular):
# compute aev using aev computer, sorted
_, aev = self.aev((species, coordinates))
radial = aev[..., :self.radial_length] radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:] angular = aev[..., self.radial_length:]
# manually sort expected values
species = self.prepare.species_to_tensor(species,
self.aev_computer.EtaR.device)
_, reverse = torch.sort(species)
expected_radial = expected_radial.index_select(1, reverse)
expected_angular = expected_angular.index_select(1, reverse)
radial_diff = expected_radial - radial radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item() radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular angular_diff = expected_angular - angular
...@@ -41,12 +30,35 @@ class TestAEV(unittest.TestCase): ...@@ -41,12 +30,35 @@ class TestAEV(unittest.TestCase):
self.assertLess(radial_max_error, self.tolerance) self.assertLess(radial_max_error, self.tolerance)
self.assertLess(angular_max_error, self.tolerance) self.assertLess(angular_max_error, self.tolerance)
def testGDB(self): def testIsomers(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _ \
= pickle.load(f)
_, aev = self.aev((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev)
def testPadding(self):
species_coordinates = []
radial_angular = []
for i in range(N): for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i)) datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, radial, angular, _, _ = pickle.load(f) coordinates, species, radial, angular, _, _ = pickle.load(f)
self._test_molecule(coordinates, species, radial, angular) species_coordinates.append(
self.prepare((species, coordinates)))
radial_angular.append((radial, angular))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
_, aev = self.aev_computer((species, coordinates))
start = 0
for expected_radial, expected_angular in radial_angular:
conformations = expected_radial.shape[0]
atoms = expected_radial.shape[1]
aev_ = aev[start:start+conformations, 0:atoms]
start += conformations
self._assertAEVEqual(expected_radial, expected_angular, aev_)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -78,7 +78,7 @@ class TestBenchmark(unittest.TestCase): ...@@ -78,7 +78,7 @@ class TestBenchmark(unittest.TestCase):
self.assertEqual(result_module.timers[i], 0) self.assertEqual(result_module.timers[i], 0)
def testAEV(self): def testAEV(self):
aev_computer = torchani.SortedAEV(benchmark=True) aev_computer = torchani.AEVComputer(benchmark=True)
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
run_module = torch.nn.Sequential(prepare, aev_computer) run_module = torch.nn.Sequential(prepare, aev_computer)
self._testModule(run_module, aev_computer, [ self._testModule(run_module, aev_computer, [
...@@ -90,7 +90,7 @@ class TestBenchmark(unittest.TestCase): ...@@ -90,7 +90,7 @@ class TestBenchmark(unittest.TestCase):
]) ])
def testANIModel(self): def testANIModel(self):
aev_computer = torchani.SortedAEV() aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.NeuroChemNNP(aev_computer.species, model = torchani.models.NeuroChemNNP(aev_computer.species,
benchmark=True) benchmark=True)
......
import sys import os
import unittest
import torchani.data
if sys.version_info.major >= 3: path = os.path.dirname(os.path.realpath(__file__))
import os path = os.path.join(path, '../dataset')
import unittest
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset')
class TestDataset(unittest.TestCase): class TestDataset(unittest.TestCase):
def _test_chunksize(self, chunksize): def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize) ds = torchani.data.ANIDataset(path, chunksize)
for i, _ in ds: for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize) self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self): def testChunk64(self):
self._test_chunksize(64) self._test_chunksize(64)
def testChunk128(self): def testChunk128(self):
self._test_chunksize(128) self._test_chunksize(128)
def testChunk32(self): def testChunk32(self):
self._test_chunksize(32) self._test_chunksize(32)
def testChunk256(self): def testChunk256(self):
self._test_chunksize(256) self._test_chunksize(256)
if __name__ == '__main__':
unittest.main() if __name__ == '__main__':
unittest.main()
...@@ -13,30 +13,38 @@ class TestEnergies(unittest.TestCase): ...@@ -13,30 +13,38 @@ class TestEnergies(unittest.TestCase):
def setUp(self): def setUp(self):
self.tolerance = 5e-5 self.tolerance = 5e-5
aev_computer = torchani.SortedAEV() aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species) self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species) nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species) shift_energy = torchani.EnergyShifter(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
nnp, shift_energy)
def _test_molecule(self, coordinates, species, energies): def testIsomers(self):
# generate a random permute
atoms = len(species)
randperm = torch.randperm(atoms)
coordinates = coordinates.index_select(1, randperm)
species = [species[i] for i in randperm.tolist()]
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_.squeeze()).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testGDB(self):
for i in range(N): for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i)) datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, energies, _ = pickle.load(f) coordinates, species, _, _, energies, _ = pickle.load(f)
self._test_molecule(coordinates, species, energies) species, coordinates = self.prepare((species, coordinates))
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testPadding(self):
species_coordinates = []
energies = []
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append(
self.prepare((species, coordinates)))
energies.append(e)
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,19 +8,24 @@ class TestEnergyShifter(unittest.TestCase): ...@@ -8,19 +8,24 @@ class TestEnergyShifter(unittest.TestCase):
def setUp(self): def setUp(self):
self.tol = 1e-5 self.tol = 1e-5
self.species = torchani.SortedAEV().species self.species = torchani.AEVComputer().species
self.prepare = torchani.PrepareInput(self.species) self.prepare = torchani.PrepareInput(self.species)
self.shift_energy = torchani.EnergyShifter(self.species) self.shift_energy = torchani.EnergyShifter(self.species)
def testSAEMatch(self): def testSAEMatch(self):
species_coordinates = []
saes = []
for _ in range(10): for _ in range(10):
k = random.choice(range(5, 30)) k = random.choice(range(5, 30))
species = random.choices(self.species, k=k) species = random.choices(self.species, k=k)
species_tensor = self.prepare.species_to_tensor( coordinates = torch.empty(1, k, 3)
species, torch.device('cpu')) species_coordinates.append(self.prepare((species, coordinates)))
e1 = self.shift_energy.sae_from_list(species) e = self.shift_energy.sae_from_list(species)
e2 = self.shift_energy.sae_from_tensor(species_tensor) saes.append(e)
self.assertLess(abs(e1 - e2), self.tol) species, _ = torchani.padding.pad_and_batch(species_coordinates)
saes_ = self.shift_energy.sae_from_tensor(species)
saes = torch.tensor(saes, dtype=saes_.dtype, device=saes_.device)
self.assertLess((saes - saes_).abs().max(), self.tol)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -18,7 +18,7 @@ class TestEnsemble(unittest.TestCase): ...@@ -18,7 +18,7 @@ class TestEnsemble(unittest.TestCase):
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates = torch.tensor(coordinates, requires_grad=True)
n = torchani.buildin_ensemble n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV() aev = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev.species) prepare = torchani.PrepareInput(aev.species)
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True) ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
ensemble = torch.nn.Sequential(prepare, aev, ensemble) ensemble = torch.nn.Sequential(prepare, aev, ensemble)
......
...@@ -12,31 +12,44 @@ class TestForce(unittest.TestCase): ...@@ -12,31 +12,44 @@ class TestForce(unittest.TestCase):
def setUp(self): def setUp(self):
self.tolerance = 1e-5 self.tolerance = 1e-5
aev_computer = torchani.SortedAEV() aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species) self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species) nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp) self.model = torch.nn.Sequential(self.prepare, aev_computer, nnp)
self.prepared2e = torch.nn.Sequential(aev_computer, nnp)
def _test_molecule(self, coordinates, species, forces):
# generate a random permute def testIsomers(self):
atoms = len(species) for i in range(N):
randperm = torch.randperm(atoms) datafile = os.path.join(path, 'test_data/{}'.format(i))
coordinates = coordinates.index_select(1, randperm) with open(datafile, 'rb') as f:
forces = forces.index_select(1, randperm) coordinates, species, _, _, _, forces = pickle.load(f)
species = [species[i] for i in randperm.tolist()] coordinates = torch.tensor(coordinates, requires_grad=True)
_, energies = self.model((species, coordinates))
coordinates = torch.tensor(coordinates, requires_grad=True) derivative = torch.autograd.grad(energies.sum(),
_, energies = self.model((species, coordinates)) coordinates)[0]
derivative = torch.autograd.grad(energies.sum(), coordinates)[0] max_diff = (forces + derivative).abs().max().item()
max_diff = (forces + derivative).abs().max().item() self.assertLess(max_diff, self.tolerance)
self.assertLess(max_diff, self.tolerance)
def testPadding(self):
def testGDB(self): species_coordinates = []
coordinates_forces = []
for i in range(N): for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i)) datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f) coordinates, species, _, _, _, forces = pickle.load(f)
self._test_molecule(coordinates, species, forces) species, coordinates = self.prepare((species, coordinates))
coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
_, energies = self.prepared2e((species, coordinates))
energies = energies.sum()
for coordinates, forces in coordinates_forces:
derivative = torch.autograd.grad(energies, coordinates,
retain_graph=True)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -17,7 +17,7 @@ if sys.version_info.major >= 3: ...@@ -17,7 +17,7 @@ if sys.version_info.major >= 3:
class TestIgnite(unittest.TestCase): class TestIgnite(unittest.TestCase):
def testIgnite(self): def testIgnite(self):
aev_computer = torchani.SortedAEV() aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species) prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species) nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species) shift_energy = torchani.EnergyShifter(aev_computer.species)
......
import unittest
import torch
import torchani
class TestPadAndBatch(unittest.TestCase):
def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testTensorShape1NSpecies(self):
species1 = torch.LongTensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testTensorSpecies(self):
species1 = torch.LongTensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.padding.present_species(species)
expected = torch.LongTensor([0, 1, 3, 7])
self.assertEqual((expected - present_species).abs().max().item(), 0)
if __name__ == '__main__':
unittest.main()
...@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter ...@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter
from . import models from . import models
from . import data from . import data
from . import ignite from . import ignite
from .aev import SortedAEV, PrepareInput from . import padding
from .aev import AEVComputer, PrepareInput
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble buildin_model_prefix, buildin_ensemble
__all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter', __all__ = ['PrepareInput', 'AEVComputer', 'EnergyShifter',
'models', 'data', 'ignite', 'models', 'data', 'padding', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir', 'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble'] 'buildin_model_prefix', 'buildin_ensemble']
...@@ -3,9 +3,10 @@ import itertools ...@@ -3,9 +3,10 @@ import itertools
import math import math
from .env import buildin_const_file from .env import buildin_const_file
from .benchmarked import BenchmarkedModule from .benchmarked import BenchmarkedModule
from . import padding
class AEVComputer(BenchmarkedModule): class AEVComputerBase(BenchmarkedModule):
__constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length', __constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
'angular_sublength', 'angular_length', 'aev_length'] 'angular_sublength', 'angular_length', 'aev_length']
...@@ -34,7 +35,7 @@ class AEVComputer(BenchmarkedModule): ...@@ -34,7 +35,7 @@ class AEVComputer(BenchmarkedModule):
""" """
def __init__(self, benchmark=False, const_file=buildin_const_file): def __init__(self, benchmark=False, const_file=buildin_const_file):
super(AEVComputer, self).__init__(benchmark) super(AEVComputerBase, self).__init__(benchmark)
self.const_file = const_file self.const_file = const_file
# load constants from const file # load constants from const file
...@@ -132,31 +133,12 @@ class PrepareInput(torch.nn.Module): ...@@ -132,31 +133,12 @@ class PrepareInput(torch.nn.Module):
values = [indices[i] for i in species] values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=device) return torch.tensor(values, dtype=torch.long, device=device)
def sort_by_species(self, species, *tensors):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
species : torch.Tensor
Tensor storing species of each atom.
*tensors : tuple
Tensors of shape (conformations, atoms, ...) for data.
Returns
-------
(species, ...)
Tensors sorted by species.
"""
species, reverse = torch.sort(species)
new_tensors = []
for t in tensors:
new_tensors.append(t.index_select(1, reverse))
return (species, *new_tensors)
def forward(self, species_coordinates): def forward(self, species_coordinates):
species, coordinates = species_coordinates species, coordinates = species_coordinates
conformations = coordinates.shape[0]
species = self.species_to_tensor(species, coordinates.device) species = self.species_to_tensor(species, coordinates.device)
return self.sort_by_species(species, coordinates) species = species.expand(conformations, -1)
return species, coordinates
def _cutoff_cosine(distances, cutoff): def _cutoff_cosine(distances, cutoff):
...@@ -188,7 +170,7 @@ def _cutoff_cosine(distances, cutoff): ...@@ -188,7 +170,7 @@ def _cutoff_cosine(distances, cutoff):
) )
class SortedAEV(AEVComputer): class AEVComputer(AEVComputerBase):
"""The AEV computer assuming input coordinates sorted by species """The AEV computer assuming input coordinates sorted by species
Attributes Attributes
...@@ -201,7 +183,7 @@ class SortedAEV(AEVComputer): ...@@ -201,7 +183,7 @@ class SortedAEV(AEVComputer):
""" """
def __init__(self, benchmark=False, const_file=buildin_const_file): def __init__(self, benchmark=False, const_file=buildin_const_file):
super(SortedAEV, self).__init__(benchmark, const_file) super(AEVComputer, self).__init__(benchmark, const_file)
if benchmark: if benchmark:
self.radial_subaev_terms = self._enable_benchmark( self.radial_subaev_terms = self._enable_benchmark(
self.radial_subaev_terms, 'radial terms') self.radial_subaev_terms, 'radial terms')
...@@ -295,7 +277,7 @@ class SortedAEV(AEVComputer): ...@@ -295,7 +277,7 @@ class SortedAEV(AEVComputer):
# flat the last 4 dimensions to view the subAEV as one dimension vector # flat the last 4 dimensions to view the subAEV as one dimension vector
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=-4)
def terms_and_indices(self, coordinates): def terms_and_indices(self, species, coordinates):
"""Compute radial and angular subAEV terms, and original indices. """Compute radial and angular subAEV terms, and original indices.
Terms will be sorted according to their distances to central atoms, Terms will be sorted according to their distances to central atoms,
...@@ -304,6 +286,9 @@ class SortedAEV(AEVComputer): ...@@ -304,6 +286,9 @@ class SortedAEV(AEVComputer):
Parameters Parameters
---------- ----------
species : torch.Tensor
The tensor that specifies the species of atoms in the molecule.
The tensor must have shape (conformations, atoms)
coordinates : torch.Tensor coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3) molecule. The tensor must have shape (conformations, atoms, 3)
...@@ -333,6 +318,10 @@ class SortedAEV(AEVComputer): ...@@ -333,6 +318,10 @@ class SortedAEV(AEVComputer):
distances = vec.norm(2, -1) distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances""" """Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = torch.where(padding_mask, torch.tensor(math.inf),
distances)
distances, indices = distances.sort(-1) distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0) min_distances, _ = distances.flatten(end_dim=1).min(0)
...@@ -369,14 +358,16 @@ class SortedAEV(AEVComputer): ...@@ -369,14 +358,16 @@ class SortedAEV(AEVComputer):
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
def compute_mask_r(self, species_r): def compute_mask_r(self, species, indices_r):
"""Partition indices according to their species, radial part """Partition indices according to their species, radial part
Parameters Parameters
---------- ----------
species_r : torch.Tensor indices_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors) storing Tensor of shape (conformations, atoms, neighbors).
species of neighbors. Let l = indices_r(i,j,k), then this means that
radial_terms(i,j,k,:) is in the subAEV term of conformation i
between atom j and atom l.
Returns Returns
------- -------
...@@ -384,11 +375,14 @@ class SortedAEV(AEVComputer): ...@@ -384,11 +375,14 @@ class SortedAEV(AEVComputer):
Tensor of shape (conformations, atoms, neighbors, all species) Tensor of shape (conformations, atoms, neighbors, all species)
storing the mask for each species. storing the mask for each species.
""" """
species_r = species.gather(-1, indices_r)
"""Tensor of shape (conformations, atoms, neighbors) storing species
of neighbors."""
mask_r = (species_r.unsqueeze(-1) == mask_r = (species_r.unsqueeze(-1) ==
torch.arange(len(self.species), device=self.EtaR.device)) torch.arange(len(self.species), device=self.EtaR.device))
return mask_r return mask_r
def compute_mask_a(self, species_a, present_species): def compute_mask_a(self, species, indices_a, present_species):
"""Partition indices according to their species, angular part """Partition indices according to their species, angular part
Parameters Parameters
...@@ -405,6 +399,7 @@ class SortedAEV(AEVComputer): ...@@ -405,6 +399,7 @@ class SortedAEV(AEVComputer):
Tensor of shape (conformations, atoms, pairs, present species, Tensor of shape (conformations, atoms, pairs, present species,
present species) storing the mask for each pair. present species) storing the mask for each pair.
""" """
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self.combinations(species_a, -1) species_a1, species_a2 = self.combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1) mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species) mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
...@@ -480,15 +475,17 @@ class SortedAEV(AEVComputer): ...@@ -480,15 +475,17 @@ class SortedAEV(AEVComputer):
def forward(self, species_coordinates): def forward(self, species_coordinates):
species, coordinates = species_coordinates species, coordinates = species_coordinates
present_species = species.unique(sorted=True)
radial_terms, angular_terms, indices_r, indices_a = \ present_species = padding.present_species(species)
self.terms_and_indices(coordinates)
species_r = species.take(indices_r) # TODO: remove this workaround after gather support broadcasting
mask_r = self.compute_mask_r(species_r) atoms = coordinates.shape[1]
species_a = species.take(indices_a) species_ = species.unsqueeze(1).expand(-1, atoms, -1)
mask_a = self.compute_mask_a(species_a, present_species)
radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(species, coordinates)
mask_r = self.compute_mask_r(species_, indices_r)
mask_a = self.compute_mask_a(species_, indices_a, present_species)
radial, angular = self.assemble(radial_terms, angular_terms, radial, angular = self.assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a) present_species, mask_r, mask_a)
......
...@@ -27,7 +27,9 @@ class EnergyShifter(torch.nn.Module): ...@@ -27,7 +27,9 @@ class EnergyShifter(torch.nn.Module):
return sum(energies) return sum(energies)
def sae_from_tensor(self, species): def sae_from_tensor(self, species):
return self.self_energies_tensor[species].sum().item() self_energies = self.self_energies_tensor[species]
self_energies[species == -1] = 0
return self_energies.sum(dim=1)
def subtract_from_dataset(self, data): def subtract_from_dataset(self, data):
sae = self.sae_from_list(data['species']) sae = self.sae_from_list(data['species'])
...@@ -36,5 +38,5 @@ class EnergyShifter(torch.nn.Module): ...@@ -36,5 +38,5 @@ class EnergyShifter(torch.nn.Module):
def forward(self, species_energies): def forward(self, species_energies):
species, energies = species_energies species, energies = species_energies
sae = self.sae_from_tensor(species) sae = self.sae_from_tensor(species).to(energies.dtype)
return species, energies + sae return species, energies + sae
import torch import torch
from ..benchmarked import BenchmarkedModule from ..benchmarked import BenchmarkedModule
from .. import padding
class ANIModel(BenchmarkedModule): class ANIModel(BenchmarkedModule):
...@@ -21,6 +22,8 @@ class ANIModel(BenchmarkedModule): ...@@ -21,6 +22,8 @@ class ANIModel(BenchmarkedModule):
with the per atom output tensor with internal shape as input, and with the per atom output tensor with internal shape as input, and
desired reduction dimension as dim, and should reduce the input into desired reduction dimension as dim, and should reduce the input into
the tensor containing desired output. the tensor containing desired output.
padding_fill : float
Default value used to fill padding atoms
output_length : int output_length : int
Length of output of each submodel. Length of output of each submodel.
timers : dict timers : dict
...@@ -28,12 +31,13 @@ class ANIModel(BenchmarkedModule): ...@@ -28,12 +31,13 @@ class ANIModel(BenchmarkedModule):
forward : total time for the forward pass forward : total time for the forward pass
""" """
def __init__(self, species, suffixes, reducer, models, def __init__(self, species, suffixes, reducer, padding_fill, models,
benchmark=False): benchmark=False):
super(ANIModel, self).__init__(benchmark) super(ANIModel, self).__init__(benchmark)
self.species = species self.species = species
self.suffixes = suffixes self.suffixes = suffixes
self.reducer = reducer self.reducer = reducer
self.padding_fill = padding_fill
for i in models: for i in models:
setattr(self, i, models[i]) setattr(self, i, models[i])
...@@ -54,33 +58,28 @@ class ANIModel(BenchmarkedModule): ...@@ -54,33 +58,28 @@ class ANIModel(BenchmarkedModule):
Returns Returns
------- -------
torch.Tensor (species, output)
species : torch.Tensor
Tensor storing the species for each atom.
output : torch.Tensor
Pytorch tensor of shape (conformations, output_length) for the Pytorch tensor of shape (conformations, output_length) for the
output of each conformation. output of each conformation.
""" """
species, aev = species_aev species, aev = species_aev
conformations = aev.shape[0] species_ = species.flatten()
atoms = len(species) present_species = padding.present_species(species)
rev_species = species.__reversed__() aev = aev.flatten(0, 1)
species_dedup = species.unique() outputs = []
per_species_outputs = [] for suffix in self.suffixes:
species = species.tolist() output = torch.full_like(species_, self.padding_fill,
rev_species = rev_species.tolist() dtype=aev.dtype)
for s in species_dedup: for i in present_species:
begin = species.index(s) s = self.species[i]
end = atoms - rev_species.index(s) model_X = getattr(self, 'model_' + s + suffix)
part_atoms = end - begin mask = (species_ == i)
y = aev[:, begin:end, :].flatten(0, 1) input = aev.index_select(0, mask.nonzero().squeeze())
output[mask] = model_X(input).squeeze()
def apply_model(suffix): output = output.view_as(species)
model_X = getattr(self, 'model_' + outputs.append(self.reducer(output, dim=1))
self.species[s] + suffix)
return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys)
y = y.view(conformations, part_atoms, -1)
per_species_outputs.append(y)
per_species_outputs = torch.cat(per_species_outputs, dim=1) return species, sum(outputs) / len(outputs)
molecule_output = self.reducer(per_species_outputs, dim=1)
return species, molecule_output
import torch
from .ani_model import ANIModel from .ani_model import ANIModel
class CustomModel(ANIModel): class CustomModel(ANIModel):
def __init__(self, per_species, reducer, def __init__(self, per_species, reducer=torch.sum, padding_fill=0,
derivative=False, derivative_graph=False, benchmark=False): derivative=False, derivative_graph=False, benchmark=False):
"""Custom single model, no ensemble """Custom single model, no ensemble
...@@ -21,4 +22,5 @@ class CustomModel(ANIModel): ...@@ -21,4 +22,5 @@ class CustomModel(ANIModel):
for i in per_species: for i in per_species:
models['model_' + i] = per_species[i] models['model_' + i] = per_species[i]
super(CustomModel, self).__init__(list(per_species.keys()), suffixes, super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
reducer, models, benchmark) reducer, padding_fill, models,
benchmark)
...@@ -40,8 +40,6 @@ class NeuroChemNNP(ANIModel): ...@@ -40,8 +40,6 @@ class NeuroChemNNP(ANIModel):
network_dirs.append(network_dir) network_dirs.append(network_dir)
suffixes.append(suffix) suffixes.append(suffix)
reducer = torch.sum
models = {} models = {}
for network_dir, suffix in zip(network_dirs, suffixes): for network_dir, suffix in zip(network_dirs, suffixes):
for i in species: for i in species:
...@@ -49,5 +47,5 @@ class NeuroChemNNP(ANIModel): ...@@ -49,5 +47,5 @@ class NeuroChemNNP(ANIModel):
network_dir, 'ANN-{}.nnf'.format(i)) network_dir, 'ANN-{}.nnf'.format(i))
model_X = NeuroChemAtomicNetwork(filename) model_X = NeuroChemAtomicNetwork(filename)
models['model_' + i + suffix] = model_X models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(species, suffixes, reducer, super(NeuroChemNNP, self).__init__(species, suffixes, torch.sum,
models, benchmark) 0, models, benchmark)
import torch
def pad_and_batch(species_coordinates):
max_atoms = max([c.shape[1] for _, c in species_coordinates])
species = []
coordinates = []
for s, c in species_coordinates:
natoms = c.shape[1]
if len(s.shape) == 1:
s = s.unsqueeze(0)
if natoms < max_atoms:
padding = torch.full((s.shape[0], max_atoms - natoms), -1,
dtype=torch.long, device=s.device)
s = torch.cat([s, padding], dim=1)
padding = torch.full((c.shape[0], max_atoms - natoms, 3), 0,
dtype=c.dtype, device=c.device)
c = torch.cat([c, padding], dim=1)
s = s.expand(c.shape[0], max_atoms)
species.append(s)
coordinates.append(c)
return torch.cat(species), torch.cat(coordinates)
def present_species(species):
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1:
present_species = present_species[1:]
return present_species
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment