Commit 1055f1f5 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Add element names to ANIModel (#398)

* Add element names to ANIModel

* nc trainer
parent 66c3743c
......@@ -18,7 +18,10 @@ class TestNeuroChem(unittest.TestCase):
# test if loader construct correct model
self.assertEqual(trainer.aev_computer.aev_length, 384)
m = trainer.nn
H, C, N, O = m # noqa: E741
H = m['H']
C = m['C']
N = m['N']
O = m['O'] # noqa: E741
self.assertIsInstance(H[0], torch.nn.Linear)
self.assertListEqual(list(H[0].weight.shape), [160, 384])
self.assertIsInstance(H[1], torch.nn.CELU)
......
......@@ -15,6 +15,7 @@ from ..nn import ANIModel, Ensemble, Gaussian, Sequential
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..optim import AdamW
from collections import OrderedDict
class Constants(collections.abc.Mapping):
......@@ -240,10 +241,10 @@ def load_model(species, dir_):
chemical symbols of each supported atom type in correct order.
dir_ (str): String for directory storing network configurations.
"""
models = []
models = OrderedDict()
for i in species:
filename = os.path.join(dir_, 'ANN-{}.nnf'.format(i))
models.append(load_atomic_network(filename))
models[i] = load_atomic_network(filename)
return ANIModel(models)
......@@ -496,8 +497,8 @@ if sys.version_info[0] > 2:
input_size, network_setup = network_setup
if input_size != self.aev_computer.aev_length:
raise ValueError('AEV size and input size does not match')
atomic_nets = {}
for atom_type in network_setup:
atomic_nets = OrderedDict()
for atom_type in self.consts.species:
layers = network_setup[atom_type]
modules = []
i = input_size
......@@ -537,7 +538,7 @@ if sys.version_info[0] > 2:
'unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.nn = ANIModel([atomic_nets[s] for s in self.consts.species])
self.nn = ANIModel(atomic_nets)
# initialize weights and biases
self.nn.apply(init_params)
......
import torch
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional
......@@ -13,7 +14,7 @@ class SpeciesCoordinates(NamedTuple):
coordinates: Tensor
class ANIModel(torch.nn.ModuleList):
class ANIModel(torch.nn.ModuleDict):
"""ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing
......@@ -31,6 +32,18 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`.
"""
@staticmethod
def ensureOrderedDict(modules):
if isinstance(modules, OrderedDict):
return modules
od = OrderedDict()
for i, m in enumerate(modules):
od[str(i)] = m
return od
def __init__(self, modules):
super(ANIModel, self).__init__(self.ensureOrderedDict(modules))
def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
......@@ -42,7 +55,7 @@ class ANIModel(torch.nn.ModuleList):
output = aev.new_zeros(species_.shape)
for i, m in enumerate(self):
for i, (_, m) in enumerate(self.items()):
mask = (species_ == i)
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment