Unverified Commit b7cab4f1 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Support batching non isomers in models and aev computer (#62)

parent b63f6e40
......@@ -8,7 +8,7 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat') # noqa: E501
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
aev_computer = torchani.SortedAEV(const_file=const_file)
aev_computer = torchani.AEVComputer(const_file=const_file)
prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
......
......@@ -18,10 +18,9 @@ def atomic():
def get_or_create_model(filename, benchmark=False,
device=torch.device('cpu')):
aev_computer = torchani.SortedAEV(benchmark=benchmark)
aev_computer = torchani.AEVComputer(benchmark=benchmark)
prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.CustomModel(
reducer=torch.sum,
benchmark=benchmark,
per_species={
'C': atomic(),
......
......@@ -36,7 +36,7 @@ parser = parser.parse_args()
# load modules and datasets
device = torch.device(parser.device)
aev_computer = torchani.SortedAEV(const_file=parser.const_file)
aev_computer = torchani.AEVComputer(const_file=parser.const_file)
prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species,
from_=parser.network_dir,
......
......@@ -11,7 +11,7 @@ N = 97
class TestAEV(unittest.TestCase):
def setUp(self):
self.aev_computer = torchani.SortedAEV()
self.aev_computer = torchani.AEVComputer()
self.radial_length = self.aev_computer.radial_length
self.prepare = torchani.PrepareInput(self.aev_computer.species)
self.aev = torch.nn.Sequential(
......@@ -20,20 +20,9 @@ class TestAEV(unittest.TestCase):
)
self.tolerance = 1e-5
def _test_molecule(self, coordinates, species, expected_radial,
expected_angular):
# compute aev using aev computer, sorted
_, aev = self.aev((species, coordinates))
def _assertAEVEqual(self, expected_radial, expected_angular, aev):
radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:]
# manually sort expected values
species = self.prepare.species_to_tensor(species,
self.aev_computer.EtaR.device)
_, reverse = torch.sort(species)
expected_radial = expected_radial.index_select(1, reverse)
expected_angular = expected_angular.index_select(1, reverse)
radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular
......@@ -41,12 +30,35 @@ class TestAEV(unittest.TestCase):
self.assertLess(radial_max_error, self.tolerance)
self.assertLess(angular_max_error, self.tolerance)
def testGDB(self):
def testIsomers(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _ \
= pickle.load(f)
_, aev = self.aev((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev)
def testPadding(self):
species_coordinates = []
radial_angular = []
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, radial, angular, _, _ = pickle.load(f)
self._test_molecule(coordinates, species, radial, angular)
species_coordinates.append(
self.prepare((species, coordinates)))
radial_angular.append((radial, angular))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
_, aev = self.aev_computer((species, coordinates))
start = 0
for expected_radial, expected_angular in radial_angular:
conformations = expected_radial.shape[0]
atoms = expected_radial.shape[1]
aev_ = aev[start:start+conformations, 0:atoms]
start += conformations
self._assertAEVEqual(expected_radial, expected_angular, aev_)
if __name__ == '__main__':
......
......@@ -78,7 +78,7 @@ class TestBenchmark(unittest.TestCase):
self.assertEqual(result_module.timers[i], 0)
def testAEV(self):
aev_computer = torchani.SortedAEV(benchmark=True)
aev_computer = torchani.AEVComputer(benchmark=True)
prepare = torchani.PrepareInput(aev_computer.species)
run_module = torch.nn.Sequential(prepare, aev_computer)
self._testModule(run_module, aev_computer, [
......@@ -90,7 +90,7 @@ class TestBenchmark(unittest.TestCase):
])
def testANIModel(self):
aev_computer = torchani.SortedAEV()
aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.NeuroChemNNP(aev_computer.species,
benchmark=True)
......
import sys
import os
import unittest
import torchani.data
if sys.version_info.major >= 3:
import os
import unittest
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset')
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset')
class TestDataset(unittest.TestCase):
class TestDataset(unittest.TestCase):
def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize)
for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize)
for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self):
self._test_chunksize(64)
def testChunk64(self):
self._test_chunksize(64)
def testChunk128(self):
self._test_chunksize(128)
def testChunk128(self):
self._test_chunksize(128)
def testChunk32(self):
self._test_chunksize(32)
def testChunk32(self):
self._test_chunksize(32)
def testChunk256(self):
self._test_chunksize(256)
def testChunk256(self):
self._test_chunksize(256)
if __name__ == '__main__':
unittest.main()
if __name__ == '__main__':
unittest.main()
......@@ -13,30 +13,38 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
aev_computer = torchani.AEVComputer()
self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer,
nnp, shift_energy)
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
def _test_molecule(self, coordinates, species, energies):
# generate a random permute
atoms = len(species)
randperm = torch.randperm(atoms)
coordinates = coordinates.index_select(1, randperm)
species = [species[i] for i in randperm.tolist()]
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_.squeeze()).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testGDB(self):
def testIsomers(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, energies, _ = pickle.load(f)
self._test_molecule(coordinates, species, energies)
species, coordinates = self.prepare((species, coordinates))
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testPadding(self):
species_coordinates = []
energies = []
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append(
self.prepare((species, coordinates)))
energies.append(e)
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
if __name__ == '__main__':
......
......@@ -8,19 +8,24 @@ class TestEnergyShifter(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.species = torchani.SortedAEV().species
self.species = torchani.AEVComputer().species
self.prepare = torchani.PrepareInput(self.species)
self.shift_energy = torchani.EnergyShifter(self.species)
def testSAEMatch(self):
species_coordinates = []
saes = []
for _ in range(10):
k = random.choice(range(5, 30))
species = random.choices(self.species, k=k)
species_tensor = self.prepare.species_to_tensor(
species, torch.device('cpu'))
e1 = self.shift_energy.sae_from_list(species)
e2 = self.shift_energy.sae_from_tensor(species_tensor)
self.assertLess(abs(e1 - e2), self.tol)
coordinates = torch.empty(1, k, 3)
species_coordinates.append(self.prepare((species, coordinates)))
e = self.shift_energy.sae_from_list(species)
saes.append(e)
species, _ = torchani.padding.pad_and_batch(species_coordinates)
saes_ = self.shift_energy.sae_from_tensor(species)
saes = torch.tensor(saes, dtype=saes_.dtype, device=saes_.device)
self.assertLess((saes - saes_).abs().max(), self.tol)
if __name__ == '__main__':
......
......@@ -18,7 +18,7 @@ class TestEnsemble(unittest.TestCase):
coordinates = torch.tensor(coordinates, requires_grad=True)
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV()
aev = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev.species)
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
ensemble = torch.nn.Sequential(prepare, aev, ensemble)
......
......@@ -12,31 +12,44 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
aev_computer = torchani.AEVComputer()
self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, forces):
# generate a random permute
atoms = len(species)
randperm = torch.randperm(atoms)
coordinates = coordinates.index_select(1, randperm)
forces = forces.index_select(1, randperm)
species = [species[i] for i in randperm.tolist()]
coordinates = torch.tensor(coordinates, requires_grad=True)
_, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testGDB(self):
self.model = torch.nn.Sequential(self.prepare, aev_computer, nnp)
self.prepared2e = torch.nn.Sequential(aev_computer, nnp)
def testIsomers(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f)
coordinates = torch.tensor(coordinates, requires_grad=True)
_, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(),
coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testPadding(self):
species_coordinates = []
coordinates_forces = []
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f)
self._test_molecule(coordinates, species, forces)
species, coordinates = self.prepare((species, coordinates))
coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
_, energies = self.prepared2e((species, coordinates))
energies = energies.sum()
for coordinates, forces in coordinates_forces:
derivative = torch.autograd.grad(energies, coordinates,
retain_graph=True)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
if __name__ == '__main__':
......
......@@ -17,7 +17,7 @@ if sys.version_info.major >= 3:
class TestIgnite(unittest.TestCase):
def testIgnite(self):
aev_computer = torchani.SortedAEV()
aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
......
import unittest
import torch
import torchani
class TestPadAndBatch(unittest.TestCase):
def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testTensorShape1NSpecies(self):
species1 = torch.LongTensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testTensorSpecies(self):
species1 = torch.LongTensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.padding.present_species(species)
expected = torch.LongTensor([0, 1, 3, 7])
self.assertEqual((expected - present_species).abs().max().item(), 0)
if __name__ == '__main__':
unittest.main()
......@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter
from . import models
from . import data
from . import ignite
from .aev import SortedAEV, PrepareInput
from . import padding
from .aev import AEVComputer, PrepareInput
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble
__all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter',
'models', 'data', 'ignite',
__all__ = ['PrepareInput', 'AEVComputer', 'EnergyShifter',
'models', 'data', 'padding', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble']
......@@ -3,9 +3,10 @@ import itertools
import math
from .env import buildin_const_file
from .benchmarked import BenchmarkedModule
from . import padding
class AEVComputer(BenchmarkedModule):
class AEVComputerBase(BenchmarkedModule):
__constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
'angular_sublength', 'angular_length', 'aev_length']
......@@ -34,7 +35,7 @@ class AEVComputer(BenchmarkedModule):
"""
def __init__(self, benchmark=False, const_file=buildin_const_file):
super(AEVComputer, self).__init__(benchmark)
super(AEVComputerBase, self).__init__(benchmark)
self.const_file = const_file
# load constants from const file
......@@ -132,31 +133,12 @@ class PrepareInput(torch.nn.Module):
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=device)
def sort_by_species(self, species, *tensors):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
species : torch.Tensor
Tensor storing species of each atom.
*tensors : tuple
Tensors of shape (conformations, atoms, ...) for data.
Returns
-------
(species, ...)
Tensors sorted by species.
"""
species, reverse = torch.sort(species)
new_tensors = []
for t in tensors:
new_tensors.append(t.index_select(1, reverse))
return (species, *new_tensors)
def forward(self, species_coordinates):
species, coordinates = species_coordinates
conformations = coordinates.shape[0]
species = self.species_to_tensor(species, coordinates.device)
return self.sort_by_species(species, coordinates)
species = species.expand(conformations, -1)
return species, coordinates
def _cutoff_cosine(distances, cutoff):
......@@ -188,7 +170,7 @@ def _cutoff_cosine(distances, cutoff):
)
class SortedAEV(AEVComputer):
class AEVComputer(AEVComputerBase):
"""The AEV computer assuming input coordinates sorted by species
Attributes
......@@ -201,7 +183,7 @@ class SortedAEV(AEVComputer):
"""
def __init__(self, benchmark=False, const_file=buildin_const_file):
super(SortedAEV, self).__init__(benchmark, const_file)
super(AEVComputer, self).__init__(benchmark, const_file)
if benchmark:
self.radial_subaev_terms = self._enable_benchmark(
self.radial_subaev_terms, 'radial terms')
......@@ -295,7 +277,7 @@ class SortedAEV(AEVComputer):
# flat the last 4 dimensions to view the subAEV as one dimension vector
return ret.flatten(start_dim=-4)
def terms_and_indices(self, coordinates):
def terms_and_indices(self, species, coordinates):
"""Compute radial and angular subAEV terms, and original indices.
Terms will be sorted according to their distances to central atoms,
......@@ -304,6 +286,9 @@ class SortedAEV(AEVComputer):
Parameters
----------
species : torch.Tensor
The tensor that specifies the species of atoms in the molecule.
The tensor must have shape (conformations, atoms)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
......@@ -333,6 +318,10 @@ class SortedAEV(AEVComputer):
distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = torch.where(padding_mask, torch.tensor(math.inf),
distances)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
......@@ -369,14 +358,16 @@ class SortedAEV(AEVComputer):
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
def compute_mask_r(self, species_r):
def compute_mask_r(self, species, indices_r):
"""Partition indices according to their species, radial part
Parameters
----------
species_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors) storing
species of neighbors.
indices_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors).
Let l = indices_r(i,j,k), then this means that
radial_terms(i,j,k,:) is in the subAEV term of conformation i
between atom j and atom l.
Returns
-------
......@@ -384,11 +375,14 @@ class SortedAEV(AEVComputer):
Tensor of shape (conformations, atoms, neighbors, all species)
storing the mask for each species.
"""
species_r = species.gather(-1, indices_r)
"""Tensor of shape (conformations, atoms, neighbors) storing species
of neighbors."""
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(len(self.species), device=self.EtaR.device))
return mask_r
def compute_mask_a(self, species_a, present_species):
def compute_mask_a(self, species, indices_a, present_species):
"""Partition indices according to their species, angular part
Parameters
......@@ -405,6 +399,7 @@ class SortedAEV(AEVComputer):
Tensor of shape (conformations, atoms, pairs, present species,
present species) storing the mask for each pair.
"""
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self.combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
......@@ -480,15 +475,17 @@ class SortedAEV(AEVComputer):
def forward(self, species_coordinates):
species, coordinates = species_coordinates
present_species = species.unique(sorted=True)
radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(coordinates)
present_species = padding.present_species(species)
species_r = species.take(indices_r)
mask_r = self.compute_mask_r(species_r)
species_a = species.take(indices_a)
mask_a = self.compute_mask_a(species_a, present_species)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(species, coordinates)
mask_r = self.compute_mask_r(species_, indices_r)
mask_a = self.compute_mask_a(species_, indices_a, present_species)
radial, angular = self.assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
......
......@@ -27,7 +27,9 @@ class EnergyShifter(torch.nn.Module):
return sum(energies)
def sae_from_tensor(self, species):
return self.self_energies_tensor[species].sum().item()
self_energies = self.self_energies_tensor[species]
self_energies[species == -1] = 0
return self_energies.sum(dim=1)
def subtract_from_dataset(self, data):
sae = self.sae_from_list(data['species'])
......@@ -36,5 +38,5 @@ class EnergyShifter(torch.nn.Module):
def forward(self, species_energies):
species, energies = species_energies
sae = self.sae_from_tensor(species)
sae = self.sae_from_tensor(species).to(energies.dtype)
return species, energies + sae
import torch
from ..benchmarked import BenchmarkedModule
from .. import padding
class ANIModel(BenchmarkedModule):
......@@ -21,6 +22,8 @@ class ANIModel(BenchmarkedModule):
with the per atom output tensor with internal shape as input, and
desired reduction dimension as dim, and should reduce the input into
the tensor containing desired output.
padding_fill : float
Default value used to fill padding atoms
output_length : int
Length of output of each submodel.
timers : dict
......@@ -28,12 +31,13 @@ class ANIModel(BenchmarkedModule):
forward : total time for the forward pass
"""
def __init__(self, species, suffixes, reducer, models,
def __init__(self, species, suffixes, reducer, padding_fill, models,
benchmark=False):
super(ANIModel, self).__init__(benchmark)
self.species = species
self.suffixes = suffixes
self.reducer = reducer
self.padding_fill = padding_fill
for i in models:
setattr(self, i, models[i])
......@@ -54,33 +58,28 @@ class ANIModel(BenchmarkedModule):
Returns
-------
torch.Tensor
(species, output)
species : torch.Tensor
Tensor storing the species for each atom.
output : torch.Tensor
Pytorch tensor of shape (conformations, output_length) for the
output of each conformation.
"""
species, aev = species_aev
conformations = aev.shape[0]
atoms = len(species)
rev_species = species.__reversed__()
species_dedup = species.unique()
per_species_outputs = []
species = species.tolist()
rev_species = rev_species.tolist()
for s in species_dedup:
begin = species.index(s)
end = atoms - rev_species.index(s)
part_atoms = end - begin
y = aev[:, begin:end, :].flatten(0, 1)
def apply_model(suffix):
model_X = getattr(self, 'model_' +
self.species[s] + suffix)
return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys)
y = y.view(conformations, part_atoms, -1)
per_species_outputs.append(y)
species_ = species.flatten()
present_species = padding.present_species(species)
aev = aev.flatten(0, 1)
outputs = []
for suffix in self.suffixes:
output = torch.full_like(species_, self.padding_fill,
dtype=aev.dtype)
for i in present_species:
s = self.species[i]
model_X = getattr(self, 'model_' + s + suffix)
mask = (species_ == i)
input = aev.index_select(0, mask.nonzero().squeeze())
output[mask] = model_X(input).squeeze()
output = output.view_as(species)
outputs.append(self.reducer(output, dim=1))
per_species_outputs = torch.cat(per_species_outputs, dim=1)
molecule_output = self.reducer(per_species_outputs, dim=1)
return species, molecule_output
return species, sum(outputs) / len(outputs)
import torch
from .ani_model import ANIModel
class CustomModel(ANIModel):
def __init__(self, per_species, reducer,
def __init__(self, per_species, reducer=torch.sum, padding_fill=0,
derivative=False, derivative_graph=False, benchmark=False):
"""Custom single model, no ensemble
......@@ -21,4 +22,5 @@ class CustomModel(ANIModel):
for i in per_species:
models['model_' + i] = per_species[i]
super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
reducer, models, benchmark)
reducer, padding_fill, models,
benchmark)
......@@ -40,8 +40,6 @@ class NeuroChemNNP(ANIModel):
network_dirs.append(network_dir)
suffixes.append(suffix)
reducer = torch.sum
models = {}
for network_dir, suffix in zip(network_dirs, suffixes):
for i in species:
......@@ -49,5 +47,5 @@ class NeuroChemNNP(ANIModel):
network_dir, 'ANN-{}.nnf'.format(i))
model_X = NeuroChemAtomicNetwork(filename)
models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
models, benchmark)
super(NeuroChemNNP, self).__init__(species, suffixes, torch.sum,
0, models, benchmark)
import torch
def pad_and_batch(species_coordinates):
max_atoms = max([c.shape[1] for _, c in species_coordinates])
species = []
coordinates = []
for s, c in species_coordinates:
natoms = c.shape[1]
if len(s.shape) == 1:
s = s.unsqueeze(0)
if natoms < max_atoms:
padding = torch.full((s.shape[0], max_atoms - natoms), -1,
dtype=torch.long, device=s.device)
s = torch.cat([s, padding], dim=1)
padding = torch.full((c.shape[0], max_atoms - natoms, 3), 0,
dtype=c.dtype, device=c.device)
c = torch.cat([c, padding], dim=1)
s = s.expand(c.shape[0], max_atoms)
species.append(s)
coordinates.append(c)
return torch.cat(species), torch.cat(coordinates)
def present_species(species):
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1:
present_species = present_species[1:]
return present_species
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment