Unverified Commit a5bad5c1 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Allow easily loading individual models (#476)

* Allow easily loading individual models

* Hopefully make the tests a tiny bit faster by not loading the whole network each time
parent 3a043d45
......@@ -25,7 +25,7 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase):
def setUp(self):
self.model = torchani.models.ANI1x().double()[0]
self.model = torchani.models.ANI1x(model_index=0).double()
def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True)
......
......@@ -13,10 +13,10 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0]
self.energy_shifter = ani1x.energy_shifter
model = torchani.models.ANI1x(model_index=0)
self.aev_computer = model.aev_computer
self.nnp = model.neural_networks
self.energy_shifter = model.energy_shifter
self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
......
......@@ -12,9 +12,9 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0]
model = torchani.models.ANI1x(model_index=0)
self.aev_computer = model.aev_computer
self.nnp = model.neural_networks
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp)
def random_skip(self):
......
......@@ -15,8 +15,7 @@ class TestStructureOptimization(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-6
self.ani1x = torchani.models.ANI1x()
self.calculator = self.ani1x[0].ase()
self.calculator = torchani.models.ANI1x(model_index=0).ase()
def testRMSE(self):
datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all')
......
......@@ -26,7 +26,7 @@ directly calculate energies or get an ASE calculator. For example:
Note that the class BuiltinModels can be accessed but it is deprecated and
shouldn't be used anymore.
"""
import os
import torch
from torch import Tensor
from typing import Tuple, Optional
......@@ -53,6 +53,45 @@ class BuiltinModel(torch.nn.Module):
self.consts = consts
self.sae_dict = sae_dict
@classmethod
def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False, model_index=0):
    """Build a model from a single network of a NeuroChem ensemble.

    Args:
        info_file_path: Path of the ``.info`` file, as accepted by
            ``_parse_neurochem_resources``.
        periodic_table_index: Forwarded unchanged to the class constructor.
        model_index: Index of the ensemble member to load (default ``0``).
            Must satisfy ``0 <= model_index < ensemble_size``.

    Returns:
        An instance of ``cls`` built around the selected network.

    Raises:
        ValueError: If ``model_index`` is negative or is not smaller than
            the ensemble size read from the info file.
    """
    consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)

    # Validate both bounds before building anything: a negative index would
    # otherwise be formatted into a directory name that does not exist and
    # fail much later with an unrelated error from the network loader.
    if model_index < 0:
        raise ValueError("The model index must be non-negative, model {} can't be loaded".format(model_index))
    if (model_index >= ensemble_size):
        raise ValueError("The ensemble size is only {}, model {} can't be loaded".format(ensemble_size, model_index))

    species_converter = SpeciesConverter(consts.species)
    aev_computer = AEVComputer(**consts)
    energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
    species_to_tensor = consts.species_to_tensor

    # Each ensemble member lives in "<ensemble_prefix><index>/networks"
    network_dir = os.path.join('{}{}'.format(ensemble_prefix, model_index), 'networks')
    neural_networks = neurochem.load_model(consts.species, network_dir)

    return cls(species_converter, aev_computer, neural_networks,
               energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
@staticmethod
def _parse_neurochem_resources(info_file_path):
    """Read a NeuroChem ``.info`` file and resolve the resources it names.

    The first four lines of the info file are, in order:

    1. path to the file with the builtin constants,
    2. path to the file with the Self Atomic Energies,
    3. prefix of the neurochem resource directories of the ensemble,
    4. number of models in the ensemble.

    All paths (including ``info_file_path`` itself) are looked up under the
    ``resources/`` directory of the parent package of this module.

    Returns:
        A tuple ``(consts, sae_file, ensemble_prefix, ensemble_size)`` where
        ``consts`` is a ``neurochem.Constants`` object, ``sae_file`` and
        ``ensemble_prefix`` are resolved paths and ``ensemble_size`` is an
        ``int``.
    """
    # Parent package of this module; resources are stored relative to it
    resource_package = __name__.rpartition('.')[0]

    def resolve(relative_path):
        return resource_filename(resource_package, 'resources/' + relative_path)

    with open(resolve(info_file_path)) as info:
        header = [line.strip() for line in info.readlines()[:4]]
    const_relative, sae_relative, prefix_relative, size_text = header

    const_file = resolve(const_relative)
    sae_file = resolve(sae_relative)
    ensemble_prefix = resolve(prefix_relative)
    ensemble_size = int(size_text)

    return neurochem.Constants(const_file), sae_file, ensemble_prefix, ensemble_size
def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
......@@ -159,31 +198,15 @@ class BuiltinEnsemble(BuiltinModel):
@classmethod
def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False):
def get_resource(file_path):
package_name = '.'.join(__name__.split('.')[:-1])
return resource_filename(package_name, 'resources/' + file_path)
info_file = get_resource(info_file_path)
with open(info_file) as f:
# const_file: Path to the file with the builtin constants.
# sae_file: Path to the file with the Self Atomic Energies.
# ensemble_prefix: Prefix of the neurochem resource directories.
lines = [x.strip() for x in f.readlines()][:4]
const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
const_file = get_resource(const_file_path)
sae_file = get_resource(sae_file_path)
ensemble_prefix = get_resource(ensemble_prefix_path)
ensemble_size = int(ensemble_size)
consts = neurochem.Constants(const_file)
# this is used to load the whole ensemble (all ensemble_size models)
consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)
species_converter = SpeciesConverter(consts.species)
aev_computer = AEVComputer(**consts)
neural_networks = neurochem.load_model_ensemble(consts.species,
ensemble_prefix, ensemble_size)
energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
species_to_tensor = consts.species_to_tensor
neural_networks = neurochem.load_model_ensemble(consts.species,
ensemble_prefix, ensemble_size)
return cls(species_converter, aev_computer, neural_networks,
energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
......@@ -220,7 +243,7 @@ class BuiltinEnsemble(BuiltinModel):
return len(self.neural_networks)
def ANI1x(periodic_table_index=False):
def ANI1x(periodic_table_index=False, model_index=None):
"""The ANI-1x model as in `ani-1x_8x on GitHub`_ and `Active Learning Paper`_.
The ANI-1x model is an ensemble of 8 networks that was trained using
......@@ -234,10 +257,13 @@ def ANI1x(periodic_table_index=False):
.. _Active Learning Paper:
https://aip.scitation.org/doi/abs/10.1063/1.5023802
"""
return BuiltinEnsemble._from_neurochem_resources('ani-1x_8x.info', periodic_table_index)
info_file = 'ani-1x_8x.info'
if model_index is None:
return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)
def ANI1ccx(periodic_table_index=False):
def ANI1ccx(periodic_table_index=False, model_index=None):
"""The ANI-1ccx model as in `ani-1ccx_8x on GitHub`_ and `Transfer Learning Paper`_.
The ANI-1ccx model is an ensemble of 8 networks that was trained
......@@ -252,4 +278,7 @@ def ANI1ccx(periodic_table_index=False):
.. _Transfer Learning Paper:
https://doi.org/10.26434/chemrxiv.6744440.v1
"""
return BuiltinEnsemble._from_neurochem_resources('ani-1ccx_8x.info', periodic_table_index)
info_file = 'ani-1ccx_8x.info'
if model_index is None:
return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment