Unverified Commit 2cb357bd authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

make aev computer accept species as tensor (#38)

parent 629bc698
......@@ -16,6 +16,7 @@ class TestAEV(unittest.TestCase):
def _test_molecule(self, coordinates, species, expected_radial,
expected_angular):
species = self.aev.species_to_tensor(species)
aev = self.aev((coordinates, species))
radial = aev[..., :self.aev.radial_length]
angular = aev[..., self.aev.radial_length:]
......
......@@ -42,7 +42,8 @@ class TestBenchmark(unittest.TestCase):
old_timers = copy.copy(module.timers)
for _ in range(self.count):
if isinstance(module, torchani.aev.AEVComputer):
module((self.coordinates, self.species))
species = module.species_to_tensor(self.species)
module((self.coordinates, species))
else:
module(self.coordinates, self.species)
for i in keys:
......
......@@ -18,7 +18,6 @@ class TestEnsemble(unittest.TestCase):
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV(device=torch.device('cpu'))
coordinates, species = aev.sort_by_species(coordinates, species)
ensemble = torchani.models.NeuroChemNNP(aev, derivative=True,
ensemble=True)
models = [torchani.models.
......
......@@ -11,7 +11,7 @@ if sys.version_info.major >= 3:
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
chunksize = 8
chunksize = 4
threshold = 1e-5
dtype = torch.float32
device = torch.device('cpu')
......
import torch
import itertools
import numpy
import math
from .env import buildin_const_file, default_dtype, default_device
from .benchmarked import BenchmarkedModule
......@@ -89,26 +89,26 @@ class AEVComputer(BenchmarkedModule):
self.ShfA = self.ShfA.view(1, 1, -1, 1)
self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
def sort_by_species(self, data, species):
def sort_by_species(self, species, *tensors):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
data : torch.Tensor
Tensor of shape (conformations, atoms, ...) for data.
species : list
List storing species of each atom.
species : torch.Tensor
Tensor storing species of each atom.
*tensors : tuple
Tensors of shape (conformations, atoms, ...) for data.
Returns
-------
(torch.Tensor, list)
Tuple of (sorted data, sorted species).
(species, ...)
Tensors sorted by species.
"""
atoms = list(zip(species, torch.unbind(data, 1)))
atoms = sorted(atoms, key=lambda x: self.species.index(x[0]))
species = [s for s, _ in atoms]
data = torch.stack([c for _, c in atoms], dim=1)
return data, species
species, reverse = torch.sort(species)
new_tensors = []
for t in tensors:
new_tensors.append(t.index_select(1, reverse))
return (species, *tensors)
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
......@@ -158,7 +158,7 @@ def _cutoff_cosine(distances, cutoff):
"""
return torch.where(
distances <= cutoff,
0.5 * torch.cos(numpy.pi * distances / cutoff) + 0.5,
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
torch.zeros_like(distances)
)
......@@ -484,7 +484,6 @@ class SortedAEV(AEVComputer):
def forward(self, coordinates_species):
coordinates, species = coordinates_species
species = self.species_to_tensor(species)
present_species = species.unique(sorted=True)
radial_terms, angular_terms, indices_r, indices_a = \
......
......@@ -82,8 +82,8 @@ class ANIModel(BenchmarkedModule):
aev : torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs.
species : list of string
List of string storing the species for each atom.
species : torch.Tensor
Tensor storing the species for each atom.
Returns
-------
......@@ -93,17 +93,19 @@ class ANIModel(BenchmarkedModule):
"""
conformations = aev.shape[0]
atoms = len(species)
rev_species = species[::-1]
species_dedup = sorted(
set(species), key=self.aev_computer.species.index)
rev_species = species.__reversed__()
species_dedup = species.unique()
per_species_outputs = []
species = species.tolist()
rev_species = rev_species.tolist()
for s in species_dedup:
begin = species.index(s)
end = atoms - rev_species.index(s)
y = aev[:, begin:end, :].reshape(-1, self.aev_computer.aev_length)
def apply_model(suffix):
model_X = getattr(self, 'model_' + s + suffix)
model_X = getattr(self, 'model_' +
self.aev_computer.species[s] + suffix)
return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys)
......@@ -145,12 +147,13 @@ class ANIModel(BenchmarkedModule):
when the derivative is off, and the second tensor is a tensor of
shape (conformation, atoms, 3) storing the d(output)/dR.
"""
species = self.aev_computer.species_to_tensor(species)
if not self.derivative:
coordinates = coordinates.detach()
else:
coordinates = torch.tensor(coordinates, requires_grad=True)
_coordinates, _species = self.aev_computer.sort_by_species(
coordinates, species)
_species, _coordinates, = self.aev_computer.sort_by_species(
species, coordinates)
aev = self.aev_computer((_coordinates, _species))
output = self.aev_to_output(aev, _species)
if not self.derivative:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment