Unverified Commit 629bc698 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

make aev computer take a single tuple as input (#37)

parent d07bd02c
...@@ -16,7 +16,9 @@ class TestAEV(unittest.TestCase): ...@@ -16,7 +16,9 @@ class TestAEV(unittest.TestCase):
def _test_molecule(self, coordinates, species, expected_radial, def _test_molecule(self, coordinates, species, expected_radial,
expected_angular): expected_angular):
radial, angular = self.aev(coordinates, species) aev = self.aev((coordinates, species))
radial = aev[..., :self.aev.radial_length]
angular = aev[..., self.aev.radial_length:]
radial_diff = expected_radial - radial radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item() radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular angular_diff = expected_angular - angular
......
...@@ -41,7 +41,10 @@ class TestBenchmark(unittest.TestCase): ...@@ -41,7 +41,10 @@ class TestBenchmark(unittest.TestCase):
self.assertEqual(module.timers[i], 0) self.assertEqual(module.timers[i], 0)
old_timers = copy.copy(module.timers) old_timers = copy.copy(module.timers)
for _ in range(self.count): for _ in range(self.count):
module(self.coordinates, self.species) if isinstance(module, torchani.aev.AEVComputer):
module((self.coordinates, self.species))
else:
module(self.coordinates, self.species)
for i in keys: for i in keys:
self.assertLess(old_timers[i], module.timers[i]) self.assertLess(old_timers[i], module.timers[i])
for i in asserts: for i in asserts:
...@@ -90,16 +93,16 @@ class TestBenchmark(unittest.TestCase): ...@@ -90,16 +93,16 @@ class TestBenchmark(unittest.TestCase):
'total>mask_r', 'total>mask_a' 'total>mask_r', 'total>mask_a'
]) ])
def testModelOnAEV(self): def testANIModel(self):
aev_computer = torchani.SortedAEV( aev_computer = torchani.SortedAEV(
dtype=self.dtype, device=self.device) dtype=self.dtype, device=self.device)
model = torchani.models.NeuroChemNNP( model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True) aev_computer, benchmark=True)
self._testModule(model, ['forward>aev', 'forward>nn']) self._testModule(model, ['forward>nn'])
model = torchani.models.NeuroChemNNP( model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True, derivative=True) aev_computer, benchmark=True, derivative=True)
self._testModule( self._testModule(
model, ['forward>aev', 'forward>nn', 'forward>derivative']) model, ['forward>nn', 'forward>derivative'])
if __name__ == '__main__': if __name__ == '__main__':
......
import torch import torch
import itertools import itertools
import numpy import numpy
from .aev_base import AEVComputer
from .env import buildin_const_file, default_dtype, default_device from .env import buildin_const_file, default_dtype, default_device
from .benchmarked import BenchmarkedModule
class AEVComputer(BenchmarkedModule):
__constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
"""Base class of various implementations of AEV computer
Attributes
----------
benchmark : boolean
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
Cutoff radius
EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
radial_sublength : int
The length of radial subaev of a single species
radial_length : int
The length of full radial aev
angular_sublength : int
The length of angular subaev of a single species
angular_length : int
The length of full angular aev
aev_length : int
The length of full aev
"""
def __init__(self, benchmark=False, dtype=default_dtype,
device=default_device, const_file=buildin_const_file):
super(AEVComputer, self).__init__(benchmark)
self.dtype = dtype
self.const_file = const_file
self.device = device
# load constants from const file
with open(const_file) as f:
for i in f:
try:
line = [x.strip() for x in i.split('=')]
name = line[0]
value = line[1]
if name == 'Rcr' or name == 'Rca':
setattr(self, name, float(value))
elif name in ['EtaR', 'ShfR', 'Zeta',
'ShfZ', 'EtaA', 'ShfA']:
value = [float(x.strip()) for x in value.replace(
'[', '').replace(']', '').split(',')]
value = torch.tensor(value, dtype=dtype, device=device)
setattr(self, name, value)
elif name == 'Atyp':
value = [x.strip() for x in value.replace(
'[', '').replace(']', '').split(',')]
self.species = value
except Exception:
raise ValueError('unable to parse const file')
# Compute lengths
self.radial_sublength = self.EtaR.shape[0] * self.ShfR.shape[0]
self.radial_length = len(self.species) * self.radial_sublength
self.angular_sublength = self.EtaA.shape[0] * \
self.Zeta.shape[0] * self.ShfA.shape[0] * self.ShfZ.shape[0]
species = len(self.species)
self.angular_length = int(
(species * (species + 1)) / 2) * self.angular_sublength
self.aev_length = self.radial_length + self.angular_length
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.EtaR = self.EtaR.view(-1, 1)
self.ShfR = self.ShfR.view(1, -1)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self.EtaA = self.EtaA.view(-1, 1, 1, 1)
self.Zeta = self.Zeta.view(1, -1, 1, 1)
self.ShfA = self.ShfA.view(1, 1, -1, 1)
self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
def sort_by_species(self, data, species):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
data : torch.Tensor
Tensor of shape (conformations, atoms, ...) for data.
species : list
List storing species of each atom.
Returns
-------
(torch.Tensor, list)
Tuple of (sorted data, sorted species).
"""
atoms = list(zip(species, torch.unbind(data, 1)))
atoms = sorted(atoms, key=lambda x: self.species.index(x[0]))
species = [s for s, _ in atoms]
data = torch.stack([c for _, c in atoms], dim=1)
return data, species
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
Parameters
----------
(coordinates, species)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise NotImplementedError('subclass must override this method')
def _cutoff_cosine(distances, cutoff): def _cutoff_cosine(distances, cutoff):
...@@ -353,7 +482,8 @@ class SortedAEV(AEVComputer): ...@@ -353,7 +482,8 @@ class SortedAEV(AEVComputer):
return radial_aevs, torch.cat(angular_aevs, dim=2) return radial_aevs, torch.cat(angular_aevs, dim=2)
def forward(self, coordinates, species): def forward(self, coordinates_species):
coordinates, species = coordinates_species
species = self.species_to_tensor(species) species = self.species_to_tensor(species)
present_species = species.unique(sorted=True) present_species = species.unique(sorted=True)
...@@ -365,24 +495,7 @@ class SortedAEV(AEVComputer): ...@@ -365,24 +495,7 @@ class SortedAEV(AEVComputer):
species_a = species[indices_a] species_a = species[indices_a]
mask_a = self.compute_mask_a(species_a, present_species) mask_a = self.compute_mask_a(species_a, present_species)
return self.assemble(radial_terms, angular_terms, present_species, radial, angular = self.assemble(radial_terms, angular_terms,
mask_r, mask_a) present_species, mask_r, mask_a)
fullaev = torch.cat([radial, angular], dim=2)
def export_radial_subaev_onnx(self, filename): return fullaev
"""Export the operation that compute radial subaev into onnx format
Parameters
----------
filename : string
Name of the file to store exported networks.
"""
class M(torch.nn.Module):
def __init__(self, outerself):
super(M, self).__init__()
self.outerself = outerself
def forward(self, center, neighbors):
return self.outerself.radial_subaev(center, neighbors)
dummy_center = torch.randn(1, 3)
dummy_neighbors = torch.randn(1, 5, 3)
torch.onnx.export(M(self), (dummy_center, dummy_neighbors), filename)
import torch
from .env import buildin_const_file, default_dtype, default_device
from .benchmarked import BenchmarkedModule
class AEVComputer(BenchmarkedModule):
__constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
"""Base class of various implementations of AEV computer
Attributes
----------
benchmark : boolean
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
Cutoff radius
EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
radial_sublength : int
The length of radial subaev of a single species
radial_length : int
The length of full radial aev
angular_sublength : int
The length of angular subaev of a single species
angular_length : int
The length of full angular aev
aev_length : int
The length of full aev
"""
def __init__(self, benchmark=False, dtype=default_dtype,
device=default_device, const_file=buildin_const_file):
super(AEVComputer, self).__init__(benchmark)
self.dtype = dtype
self.const_file = const_file
self.device = device
# load constants from const file
with open(const_file) as f:
for i in f:
try:
line = [x.strip() for x in i.split('=')]
name = line[0]
value = line[1]
if name == 'Rcr' or name == 'Rca':
setattr(self, name, float(value))
elif name in ['EtaR', 'ShfR', 'Zeta',
'ShfZ', 'EtaA', 'ShfA']:
value = [float(x.strip()) for x in value.replace(
'[', '').replace(']', '').split(',')]
value = torch.tensor(value, dtype=dtype, device=device)
setattr(self, name, value)
elif name == 'Atyp':
value = [x.strip() for x in value.replace(
'[', '').replace(']', '').split(',')]
self.species = value
except Exception:
raise ValueError('unable to parse const file')
# Compute lengths
self.radial_sublength = self.EtaR.shape[0] * self.ShfR.shape[0]
self.radial_length = len(self.species) * self.radial_sublength
self.angular_sublength = self.EtaA.shape[0] * \
self.Zeta.shape[0] * self.ShfA.shape[0] * self.ShfZ.shape[0]
species = len(self.species)
self.angular_length = int(
(species * (species + 1)) / 2) * self.angular_sublength
self.aev_length = self.radial_length + self.angular_length
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.EtaR = self.EtaR.view(-1, 1)
self.ShfR = self.ShfR.view(1, -1)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self.EtaA = self.EtaA.view(-1, 1, 1, 1)
self.Zeta = self.Zeta.view(1, -1, 1, 1)
self.ShfA = self.ShfA.view(1, 1, -1, 1)
self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
def sort_by_species(self, data, species):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
data : torch.Tensor
Tensor of shape (conformations, atoms, ...) for data.
species : list
List storing species of each atom.
Returns
-------
(torch.Tensor, list)
Tuple of (sorted data, sorted species).
"""
atoms = list(zip(species, torch.unbind(data, 1)))
atoms = sorted(atoms, key=lambda x: self.species.index(x[0]))
species = [s for s, _ in atoms]
data = torch.stack([c for _, c in atoms], dim=1)
return data, species
def forward(self, coordinates, species):
"""Compute AEV from coordinates and species
Parameters
----------
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise NotImplementedError('subclass must override this method')
from ..aev_base import AEVComputer from ..aev import AEVComputer
import torch import torch
from ..benchmarked import BenchmarkedModule from ..benchmarked import BenchmarkedModule
...@@ -67,7 +67,6 @@ class ANIModel(BenchmarkedModule): ...@@ -67,7 +67,6 @@ class ANIModel(BenchmarkedModule):
'derivative can only be computed for output length 1') 'derivative can only be computed for output length 1')
if benchmark: if benchmark:
self.compute_aev = self._enable_benchmark(self.compute_aev, 'aev')
self.aev_to_output = self._enable_benchmark( self.aev_to_output = self._enable_benchmark(
self.aev_to_output, 'nn') self.aev_to_output, 'nn')
if derivative: if derivative:
...@@ -75,27 +74,6 @@ class ANIModel(BenchmarkedModule): ...@@ -75,27 +74,6 @@ class ANIModel(BenchmarkedModule):
self.compute_derivative, 'derivative') self.compute_derivative, 'derivative')
self.forward = self._enable_benchmark(self.forward, 'forward') self.forward = self._enable_benchmark(self.forward, 'forward')
def compute_aev(self, coordinates, species):
"""Compute full AEV
Parameters
----------
coordinates : torch.Tensor
The pytorch tensor of shape (conformations, atoms, 3) storing
the coordinates of all atoms of all conformations.
species : list of string
List of string storing the species for each atom.
Returns
-------
torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs.
"""
radial_aev, angular_aev = self.aev_computer(coordinates, species)
fullaev = torch.cat([radial_aev, angular_aev], dim=2)
return fullaev
def aev_to_output(self, aev, species): def aev_to_output(self, aev, species):
"""Compute output from aev """Compute output from aev
...@@ -173,7 +151,7 @@ class ANIModel(BenchmarkedModule): ...@@ -173,7 +151,7 @@ class ANIModel(BenchmarkedModule):
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates = torch.tensor(coordinates, requires_grad=True)
_coordinates, _species = self.aev_computer.sort_by_species( _coordinates, _species = self.aev_computer.sort_by_species(
coordinates, species) coordinates, species)
aev = self.compute_aev(_coordinates, _species) aev = self.aev_computer((_coordinates, _species))
output = self.aev_to_output(aev, _species) output = self.aev_to_output(aev, _species)
if not self.derivative: if not self.derivative:
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment