Unverified Commit e2e740d0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

completely decouple aev computer and neural networks (#40)

parent ee31b752
......@@ -9,8 +9,10 @@ sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
nn = torchani.models.NeuroChemNNP(aev_computer, derivative=True,
from_=network_dir, ensemble=8)
prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
model = torch.nn.Sequential(prepare, aev_computer, nn)
shift_energy = torchani.EnergyShifter(sae_file)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
......@@ -23,7 +25,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H']
energy = nn(coordinates, species)
energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species)
force = -derivative
......
......@@ -9,9 +9,8 @@ def celu(x, alpha):
class AtomicNetwork(torch.nn.Module):
def __init__(self, aev_computer):
def __init__(self):
super(AtomicNetwork, self).__init__()
self.aev_computer = aev_computer
self.output_length = 1
self.layer1 = torch.nn.Linear(384, 128)
self.layer2 = torch.nn.Linear(128, 128)
......@@ -33,16 +32,23 @@ class AtomicNetwork(torch.nn.Module):
def get_or_create_model(filename, benchmark=False,
device=torchani.default_device):
aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
model = torchani.models.CustomModel(
aev_computer,
reducer=torch.sum,
benchmark=benchmark,
per_species={
'C': AtomicNetwork(aev_computer),
'H': AtomicNetwork(aev_computer),
'N': AtomicNetwork(aev_computer),
'O': AtomicNetwork(aev_computer),
'C': AtomicNetwork(),
'H': AtomicNetwork(),
'N': AtomicNetwork(),
'O': AtomicNetwork(),
})
class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
if os.path.isfile(filename):
model.load_state_dict(torch.load(filename))
else:
......
......@@ -24,23 +24,11 @@ training, validation, testing = torchani.data.load_or_create(
transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, batch_chunks)
validation = torchani.data.dataloader(validation, batch_chunks)
nnp = model.get_or_create_model(model_checkpoint)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
batch_nnp = torchani.models.BatchModel(Flatten(nnp))
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
......
......@@ -15,19 +15,7 @@ dataset = torchani.data.ANIDataset(
transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
batch_nnp = torchani.models.BatchModel(Flatten(nnp))
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters())
......@@ -53,14 +41,13 @@ def finalize_tqdm(trainer):
start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', nnp.aev_computer.timers['radial terms'])
print('Angular terms:', nnp.aev_computer.timers['angular terms'])
print('Terms and indices:', nnp.aev_computer.timers['terms and indices'])
print('Combinations:', nnp.aev_computer.timers['combinations'])
print('Mask R:', nnp.aev_computer.timers['mask_r'])
print('Mask A:', nnp.aev_computer.timers['mask_a'])
print('Assemble:', nnp.aev_computer.timers['assemble'])
print('Total AEV:', nnp.aev_computer.timers['total'])
print('NN:', nnp.timers['nn'])
print('Total Forward:', nnp.timers['forward'])
print('Radial terms:', nnp[1].timers['radial terms'])
print('Angular terms:', nnp[1].timers['angular terms'])
print('Terms and indices:', nnp[1].timers['terms and indices'])
print('Combinations:', nnp[1].timers['combinations'])
print('Mask R:', nnp[1].timers['mask_r'])
print('Mask A:', nnp[1].timers['mask_a'])
print('Assemble:', nnp[1].timers['assemble'])
print('Total AEV:', nnp[1].timers['total'])
print('NN:', nnp[2].timers['forward'])
print('Epoch time:', elapsed)
......@@ -11,15 +11,20 @@ N = 97
class TestAEV(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype):
self.aev = torchani.SortedAEV(dtype=dtype, device=torch.device('cpu'))
aev_computer = torchani.SortedAEV(dtype=dtype,
device=torch.device('cpu'))
self.radial_length = aev_computer.radial_length
self.aev = torch.nn.Sequential(
torchani.PrepareInput(aev_computer.species, aev_computer.device),
aev_computer
)
self.tolerance = 1e-5
def _test_molecule(self, coordinates, species, expected_radial,
expected_angular):
species = self.aev.species_to_tensor(species)
aev = self.aev((coordinates, species))
radial = aev[..., :self.aev.radial_length]
angular = aev[..., self.aev.radial_length:]
species, aev = self.aev((species, coordinates))
radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:]
radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular
......
......@@ -21,8 +21,11 @@ if sys.version_info.major >= 3:
ds = torchani.data.ANIDataset(path, chunksize, device=device)
loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
batch_nnp = torchani.models.BatchModel(nnp)
prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
model = torch.nn.Sequential(prepare, aev_computer, nnp)
batch_nnp = torchani.models.BatchModel(model)
for batch_input, batch_output in itertools.islice(loader, 10):
batch_output_ = batch_nnp(batch_input).squeeze()
self.assertListEqual(list(batch_output_.shape),
......
......@@ -16,7 +16,7 @@ class TestBenchmark(unittest.TestCase):
self.conformations, 8, 3, dtype=dtype, device=device)
self.count = 100
def _testModule(self, module, asserts):
def _testModule(self, run_module, result_module, asserts):
keys = []
for i in asserts:
if '>=' in i:
......@@ -36,57 +36,57 @@ class TestBenchmark(unittest.TestCase):
keys += [i[0].strip(), i[1].strip()]
else:
keys.append(i.strip())
self.assertEqual(set(module.timers.keys()), set(keys))
self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys:
self.assertEqual(module.timers[i], 0)
old_timers = copy.copy(module.timers)
self.assertEqual(result_module.timers[i], 0)
old_timers = copy.copy(result_module.timers)
for _ in range(self.count):
if isinstance(module, torchani.aev.AEVComputer):
species = module.species_to_tensor(self.species)
module((self.coordinates, species))
else:
module(self.coordinates, self.species)
run_module((self.species, self.coordinates))
for i in keys:
self.assertLess(old_timers[i], module.timers[i])
self.assertLess(old_timers[i], result_module.timers[i])
for i in asserts:
if '>=' in i:
i = i.split('>=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertGreaterEqual(
module.timers[key0], module.timers[key1])
result_module.timers[key0], result_module.timers[key1])
elif '<=' in i:
i = i.split('<=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertLessEqual(
module.timers[key0], module.timers[key1])
result_module.timers[key0], result_module.timers[key1])
elif '>' in i:
i = i.split('>')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertGreater(
module.timers[key0], module.timers[key1])
result_module.timers[key0], result_module.timers[key1])
elif '<' in i:
i = i.split('<')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertLess(module.timers[key0], module.timers[key1])
self.assertLess(result_module.timers[key0],
result_module.timers[key1])
elif '=' in i:
i = i.split('=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertEqual(module.timers[key0], module.timers[key1])
old_timers = copy.copy(module.timers)
module.reset_timers()
self.assertEqual(set(module.timers.keys()), set(keys))
self.assertEqual(result_module.timers[key0],
result_module.timers[key1])
old_timers = copy.copy(result_module.timers)
result_module.reset_timers()
self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys:
self.assertEqual(module.timers[i], 0)
self.assertEqual(result_module.timers[i], 0)
def testAEV(self):
aev_computer = torchani.SortedAEV(
benchmark=True, dtype=self.dtype, device=self.device)
self._testModule(aev_computer, [
prepare = torchani.PrepareInput(aev_computer.species, self.device)
run_module = torch.nn.Sequential(prepare, aev_computer)
self._testModule(run_module, aev_computer, [
'terms and indices>radial terms',
'terms and indices>angular terms',
'total>terms and indices',
......@@ -97,9 +97,11 @@ class TestBenchmark(unittest.TestCase):
def testANIModel(self):
aev_computer = torchani.SortedAEV(
dtype=self.dtype, device=self.device)
prepare = torchani.PrepareInput(aev_computer.species, self.device)
model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True)
self._testModule(model, ['forward>nn'])
aev_computer.species, benchmark=True)
run_module = torch.nn.Sequential(prepare, aev_computer, model)
self._testModule(run_module, model, ['forward'])
if __name__ == '__main__':
......
......@@ -14,13 +14,16 @@ class TestEnergies(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device):
self.tolerance = 5e-5
self.aev_computer = torchani.SortedAEV(
aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu'))
self.nnp = torchani.models.NeuroChemNNP(self.aev_computer)
prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, energies):
shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file)
energies_ = self.nnp(coordinates, species).squeeze()
energies_ = self.model((species, coordinates)).squeeze()
energies_ = shift_energy.add_sae(energies_, species)
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -19,15 +19,18 @@ class TestEnsemble(unittest.TestCase):
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV(device=torch.device('cpu'))
ensemble = torchani.models.NeuroChemNNP(aev, ensemble=True)
prepare = torchani.PrepareInput(aev.species, aev.device)
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
ensemble = torch.nn.Sequential(prepare, aev, ensemble)
models = [torchani.models.
NeuroChemNNP(aev, ensemble=False,
NeuroChemNNP(aev.species, ensemble=False,
from_=prefix + '{}/networks/'.format(i))
for i in range(n)]
models = [torch.nn.Sequential(prepare, aev, m) for m in models]
energy1 = ensemble(coordinates, species)
energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m(coordinates, species) for m in models]
energy2 = [m((species, coordinates)) for m in models]
energy2 = sum(energy2) / n
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item()
......
......@@ -13,14 +13,16 @@ class TestForce(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device):
self.tolerance = 1e-5
self.aev_computer = torchani.SortedAEV(
aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu'))
self.nnp = torchani.models.NeuroChemNNP(
self.aev_computer)
prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, forces):
coordinates = torch.tensor(coordinates, requires_grad=True)
energies = self.nnp(coordinates, species)
energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -26,19 +26,16 @@ if sys.version_info.major >= 3:
ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp)
model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
batch_nnp = torchani.models.BatchModel(model)
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(container.parameters())
trainer = create_supervised_trainer(
......
......@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter
from . import models
from . import data
from . import ignite
from .aev import SortedAEV
from .aev import SortedAEV, PrepareInput
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble, default_dtype, default_device
__all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data', 'ignite',
__all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter',
'models', 'data', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble',
'default_dtype', 'default_device']
......@@ -89,6 +89,52 @@ class AEVComputer(BenchmarkedModule):
self.ShfA = self.ShfA.view(1, 1, -1, 1)
self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
Parameters
----------
(species, coordinates)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
Returns
-------
(torch.Tensor, torch.LongTensor)
Returns full AEV and species
"""
raise NotImplementedError('subclass must override this method')
class PrepareInput(torch.nn.Module):
def __init__(self, species, device):
super(PrepareInput, self).__init__()
self.species = species
self.device = device
def species_to_tensor(self, species):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
Returns
-------
torch.Tensor
Long tensor for the species, where a value k means the species is
the same as self.species[k].
"""
indices = {self.species[i]: i for i in range(len(self.species))}
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=self.device)
def sort_by_species(self, species, *tensors):
"""Sort the data by its species according to the order in `self.species`
......@@ -110,28 +156,10 @@ class AEVComputer(BenchmarkedModule):
new_tensors.append(t.index_select(1, reverse))
return (species, *tensors)
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
Parameters
----------
(coordinates, species)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise NotImplementedError('subclass must override this method')
def forward(self, species_coordinates):
species, coordinates = species_coordinates
species = self.species_to_tensor(species)
return self.sort_by_species(species, coordinates)
def _cutoff_cosine(distances, cutoff):
......@@ -194,24 +222,6 @@ class SortedAEV(AEVComputer):
self.assemble = self._enable_benchmark(self.assemble, 'assemble')
self.forward = self._enable_benchmark(self.forward, 'total')
def species_to_tensor(self, species):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
Returns
-------
torch.Tensor
Long tensor for the species, where a value k means the species is
the same as self.species[k].
"""
indices = {self.species[i]: i for i in range(len(self.species))}
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=self.device)
def radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors
......@@ -482,8 +492,8 @@ class SortedAEV(AEVComputer):
return radial_aevs, torch.cat(angular_aevs, dim=2)
def forward(self, coordinates_species):
coordinates, species = coordinates_species
def forward(self, species_coordinates):
species, coordinates = species_coordinates
present_species = species.unique(sorted=True)
radial_terms, angular_terms, indices_r, indices_a = \
......@@ -497,4 +507,4 @@ class SortedAEV(AEVComputer):
radial, angular = self.assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
fullaev = torch.cat([radial, angular], dim=2)
return fullaev
return species, fullaev
from ..aev import AEVComputer
import torch
from ..benchmarked import BenchmarkedModule
......@@ -9,10 +8,10 @@ class ANIModel(BenchmarkedModule):
Attributes
----------
aev_computer : AEVComputer
The AEV computer.
species : list
Chemical symbol of supported atom species.
output_length : int
The length of output vector
The length of output vector.
suffixes : sequence
Different suffixes denote different models in an ensemble.
model_<X><suffix> : nn.Module
......@@ -26,30 +25,15 @@ class ANIModel(BenchmarkedModule):
the tensor containing desired output.
output_length : int
Length of output of each submodel.
derivative : boolean
Whether to support computing the derivative w.r.t coordinates,
i.e. d(output)/dR
derivative_graph : boolean
Whether to generate a graph for the derivative. This would be required
only if the derivative is included as part of the loss function.
timers : dict
Dictionary storing the the benchmark result. It has the following keys:
aev : time spent on computing AEV.
nn : time spent on computing output from AEV.
derivative : time spend on computing derivative w.r.t. coordinates
after the outputs is given. This key is only available if
derivative computation is turned on.
forward : total time for the forward pass
"""
def __init__(self, aev_computer, suffixes, reducer, output_length, models,
def __init__(self, species, suffixes, reducer, output_length, models,
benchmark=False):
super(ANIModel, self).__init__(benchmark)
if not isinstance(aev_computer, AEVComputer):
raise TypeError(
"ModelOnAEV: aev_computer must be a subclass of AEVComputer")
self.aev_computer = aev_computer
self.species = species
self.suffixes = suffixes
self.reducer = reducer
self.output_length = output_length
......@@ -57,20 +41,19 @@ class ANIModel(BenchmarkedModule):
setattr(self, i, models[i])
if benchmark:
self.aev_to_output = self._enable_benchmark(
self.aev_to_output, 'nn')
self.forward = self._enable_benchmark(self.forward, 'forward')
def aev_to_output(self, aev, species):
def forward(self, species_aev):
"""Compute output from aev
Parameters
----------
(species, aev)
species : torch.Tensor
Tensor storing the species for each atom.
aev : torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs.
species : torch.Tensor
Tensor storing the species for each atom.
Returns
-------
......@@ -78,6 +61,7 @@ class ANIModel(BenchmarkedModule):
Pytorch tensor of shape (conformations, output_length) for the
output of each conformation.
"""
species, aev = species_aev
conformations = aev.shape[0]
atoms = len(species)
rev_species = species.__reversed__()
......@@ -88,11 +72,11 @@ class ANIModel(BenchmarkedModule):
for s in species_dedup:
begin = species.index(s)
end = atoms - rev_species.index(s)
y = aev[:, begin:end, :].reshape(-1, self.aev_computer.aev_length)
y = aev[:, begin:end, :].flatten(0, 1)
def apply_model(suffix):
model_X = getattr(self, 'model_' +
self.aev_computer.species[s] + suffix)
self.species[s] + suffix)
return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys)
......@@ -102,26 +86,3 @@ class ANIModel(BenchmarkedModule):
per_species_outputs = torch.cat(per_species_outputs, dim=1)
molecule_output = self.reducer(per_species_outputs, dim=1)
return molecule_output
def forward(self, coordinates, species):
"""Feed forward
Parameters
----------
coordinates : torch.Tensor
The pytorch tensor of shape (conformations, atoms, 3) storing
the coordinates of all atoms of all conformations.
species : list of string
List of string storing the species for each atom.
Returns
-------
torch.Tensor
Tensor of shape (conformations, output_length) for the
output of each conformation.
"""
species = self.aev_computer.species_to_tensor(species)
_species, _coordinates, = self.aev_computer.sort_by_species(
species, coordinates)
aev = self.aev_computer((_coordinates, _species))
return self.aev_to_output(aev, _species)
......@@ -12,5 +12,5 @@ class BatchModel(torch.nn.Module):
for i in batch:
coordinates = i['coordinates']
species = i['species']
results.append(self.model(coordinates, species))
results.append(self.model((species, coordinates)))
return torch.cat(results)
......@@ -3,7 +3,7 @@ from .ani_model import ANIModel
class CustomModel(ANIModel):
def __init__(self, aev_computer, per_species, reducer,
def __init__(self, per_species, reducer,
derivative=False, derivative_graph=False, benchmark=False):
"""Custom single model, no ensemble
......@@ -31,7 +31,8 @@ class CustomModel(ANIModel):
raise ValueError(
'''output length of each atomic neural network must
match''')
super(CustomModel, self).__init__(aev_computer, suffixes, reducer,
output_length, models, benchmark)
super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
reducer, output_length, models,
benchmark)
for i in per_species:
setattr(self, 'model_' + i, per_species[i])
......@@ -12,10 +12,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
Attributes
----------
dtype : torch.dtype
Pytorch data type for tensors
device : torch.Device
The device where tensors should be.
layers : int
Number of layers.
output_length : int
......@@ -29,13 +25,11 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
The NeuroChem index for activation.
"""
def __init__(self, dtype, device, filename):
def __init__(self, filename):
"""Initialize from NeuroChem network directory.
Parameters
----------
dtype : torch.dtype
Pytorch data type for tensors
filename : string
The file name for the `.nnf` file that store network
hyperparameters. The `.bparam` and `.wparam` must be
......@@ -43,8 +37,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
"""
super(NeuroChemAtomicNetwork, self).__init__()
self.dtype = dtype
self.device = device
networ_dir = os.path.dirname(filename)
with open(filename, 'rb') as f:
buffer = f.read()
......@@ -204,7 +196,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
raise NotImplementedError(
'''different activation on different
layers are not supported''')
linear = torch.nn.Linear(in_size, out_size).type(self.dtype)
linear = torch.nn.Linear(in_size, out_size)
name = 'layer{}'.format(i)
setattr(self, name, linear)
if in_size * out_size != wsz or out_size != bsz:
......@@ -219,14 +211,12 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
wsize = in_size * out_size
fw = open(wfn, 'rb')
w = struct.unpack('{}f'.format(wsize), fw.read())
w = torch.tensor(w, dtype=self.dtype, device=self.device).view(
out_size, in_size)
w = torch.tensor(w).view(out_size, in_size)
linear.weight.data = w
fw.close()
fb = open(bfn, 'rb')
b = struct.unpack('{}f'.format(out_size), fb.read())
b = torch.tensor(b, dtype=self.dtype,
device=self.device).view(out_size)
b = torch.tensor(b).view(out_size)
linear.bias.data = b
fb.close()
......
......@@ -7,8 +7,7 @@ from ..env import buildin_network_dir, buildin_model_prefix, buildin_ensemble
class NeuroChemNNP(ANIModel):
def __init__(self, aev_computer, from_=None, ensemble=False,
derivative=False, derivative_graph=False, benchmark=False):
def __init__(self, species, from_=None, ensemble=False, benchmark=False):
"""If from_=None then ensemble must be a boolean. If ensemble=False,
then use buildin network0, else use buildin network ensemble.
If from_ != None, ensemble must be either False or an integer
......@@ -46,12 +45,10 @@ class NeuroChemNNP(ANIModel):
models = {}
output_length = None
for network_dir, suffix in zip(network_dirs, suffixes):
for i in aev_computer.species:
for i in species:
filename = os.path.join(
network_dir, 'ANN-{}.nnf'.format(i))
model_X = NeuroChemAtomicNetwork(
aev_computer.dtype, aev_computer.device,
filename)
model_X = NeuroChemAtomicNetwork(filename)
if output_length is None:
output_length = model_X.output_length
elif output_length != model_X.output_length:
......@@ -59,5 +56,5 @@ class NeuroChemNNP(ANIModel):
'''output length of each atomic neural networt
must match''')
models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(aev_computer, suffixes, reducer,
super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
output_length, models, benchmark)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment