Commit 1055f1f5 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Add element names to ANIModel (#398)

* Add element names to ANIModel

* nc trainer
parent 66c3743c
...@@ -18,7 +18,10 @@ class TestNeuroChem(unittest.TestCase): ...@@ -18,7 +18,10 @@ class TestNeuroChem(unittest.TestCase):
# test if loader construct correct model # test if loader construct correct model
self.assertEqual(trainer.aev_computer.aev_length, 384) self.assertEqual(trainer.aev_computer.aev_length, 384)
m = trainer.nn m = trainer.nn
H, C, N, O = m # noqa: E741 H = m['H']
C = m['C']
N = m['N']
O = m['O'] # noqa: E741
self.assertIsInstance(H[0], torch.nn.Linear) self.assertIsInstance(H[0], torch.nn.Linear)
self.assertListEqual(list(H[0].weight.shape), [160, 384]) self.assertListEqual(list(H[0].weight.shape), [160, 384])
self.assertIsInstance(H[1], torch.nn.CELU) self.assertIsInstance(H[1], torch.nn.CELU)
......
...@@ -15,6 +15,7 @@ from ..nn import ANIModel, Ensemble, Gaussian, Sequential ...@@ -15,6 +15,7 @@ from ..nn import ANIModel, Ensemble, Gaussian, Sequential
from ..utils import EnergyShifter, ChemicalSymbolsToInts from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer from ..aev import AEVComputer
from ..optim import AdamW from ..optim import AdamW
from collections import OrderedDict
class Constants(collections.abc.Mapping): class Constants(collections.abc.Mapping):
...@@ -240,10 +241,10 @@ def load_model(species, dir_): ...@@ -240,10 +241,10 @@ def load_model(species, dir_):
chemical symbols of each supported atom type in correct order. chemical symbols of each supported atom type in correct order.
dir_ (str): String for directory storing network configurations. dir_ (str): String for directory storing network configurations.
""" """
models = [] models = OrderedDict()
for i in species: for i in species:
filename = os.path.join(dir_, 'ANN-{}.nnf'.format(i)) filename = os.path.join(dir_, 'ANN-{}.nnf'.format(i))
models.append(load_atomic_network(filename)) models[i] = load_atomic_network(filename)
return ANIModel(models) return ANIModel(models)
...@@ -496,8 +497,8 @@ if sys.version_info[0] > 2: ...@@ -496,8 +497,8 @@ if sys.version_info[0] > 2:
input_size, network_setup = network_setup input_size, network_setup = network_setup
if input_size != self.aev_computer.aev_length: if input_size != self.aev_computer.aev_length:
raise ValueError('AEV size and input size does not match') raise ValueError('AEV size and input size does not match')
atomic_nets = {} atomic_nets = OrderedDict()
for atom_type in network_setup: for atom_type in self.consts.species:
layers = network_setup[atom_type] layers = network_setup[atom_type]
modules = [] modules = []
i = input_size i = input_size
...@@ -537,7 +538,7 @@ if sys.version_info[0] > 2: ...@@ -537,7 +538,7 @@ if sys.version_info[0] > 2:
'unrecognized parameter in layer setup') 'unrecognized parameter in layer setup')
i = o i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules) atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.nn = ANIModel([atomic_nets[s] for s in self.consts.species]) self.nn = ANIModel(atomic_nets)
# initialize weights and biases # initialize weights and biases
self.nn.apply(init_params) self.nn.apply(init_params)
......
import torch import torch
from collections import OrderedDict
from torch import Tensor from torch import Tensor
from typing import Tuple, NamedTuple, Optional from typing import Tuple, NamedTuple, Optional
...@@ -13,7 +14,7 @@ class SpeciesCoordinates(NamedTuple): ...@@ -13,7 +14,7 @@ class SpeciesCoordinates(NamedTuple):
coordinates: Tensor coordinates: Tensor
class ANIModel(torch.nn.ModuleList): class ANIModel(torch.nn.ModuleDict):
"""ANI model that compute energies from species and AEVs. """ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing Different atom types might have different modules, when computing
...@@ -31,6 +32,18 @@ class ANIModel(torch.nn.ModuleList): ...@@ -31,6 +32,18 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`. module by putting the same reference in :attr:`modules`.
""" """
@staticmethod
def ensureOrderedDict(modules):
if isinstance(modules, OrderedDict):
return modules
od = OrderedDict()
for i, m in enumerate(modules):
od[str(i)] = m
return od
def __init__(self, modules):
super(ANIModel, self).__init__(self.ensureOrderedDict(modules))
def forward(self, species_aev: Tuple[Tensor, Tensor], def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
...@@ -42,7 +55,7 @@ class ANIModel(torch.nn.ModuleList): ...@@ -42,7 +55,7 @@ class ANIModel(torch.nn.ModuleList):
output = aev.new_zeros(species_.shape) output = aev.new_zeros(species_.shape)
for i, m in enumerate(self): for i, (_, m) in enumerate(self.items()):
mask = (species_ == i) mask = (species_ == i)
midx = mask.nonzero().flatten() midx = mask.nonzero().flatten()
if midx.shape[0] > 0: if midx.shape[0] > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment