"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d74404adf8cef2982cc6353ce6e76315b43aeb7a"
Unverified commit a5bad5c1, authored by Ignacio Pickering, committed by GitHub

Allow easily loading individual models (#476)

* Allow easily loading individual models

* Hopefully make the tests a tiny bit faster by not loading the whole network each time
parent 3a043d45
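
For orientation, here is a brief usage sketch of the API change this commit introduces. It uses only names that appear in the diff below (ANI1x, model_index, aev_computer, neural_networks, energy_shifter, ase()); treat it as an illustration, not official documentation.

    import torchani

    # Unchanged behaviour: the no-argument call still returns the full 8-network ensemble.
    ensemble = torchani.models.ANI1x()

    # New: pass model_index to load a single member of the ensemble directly,
    # instead of constructing the whole ensemble and indexing into it (the old ANI1x()[0] pattern).
    model = torchani.models.ANI1x(model_index=0)

    # The single model exposes the same attributes the updated tests rely on.
    aev_computer = model.aev_computer
    networks = model.neural_networks      # one network stack, not a list of eight
    energy_shifter = model.energy_shifter
    calculator = model.ase()              # ASE calculator wrapping just this model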
@@ -25,7 +25,7 @@ def get_numeric_force(atoms, eps):
 class TestASE(unittest.TestCase):

     def setUp(self):
-        self.model = torchani.models.ANI1x().double()[0]
+        self.model = torchani.models.ANI1x(model_index=0).double()

     def testWithNumericalForceWithPBCEnabled(self):
         atoms = Diamond(symbol="C", pbc=True)
@@ -13,10 +13,10 @@ class TestEnergies(unittest.TestCase):

     def setUp(self):
         self.tolerance = 5e-5
-        ani1x = torchani.models.ANI1x()
-        self.aev_computer = ani1x.aev_computer
-        self.nnp = ani1x.neural_networks[0]
-        self.energy_shifter = ani1x.energy_shifter
+        model = torchani.models.ANI1x(model_index=0)
+        self.aev_computer = model.aev_computer
+        self.nnp = model.neural_networks
+        self.energy_shifter = model.energy_shifter
         self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
         self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
@@ -12,9 +12,9 @@ class TestForce(unittest.TestCase):

     def setUp(self):
         self.tolerance = 1e-5
-        ani1x = torchani.models.ANI1x()
-        self.aev_computer = ani1x.aev_computer
-        self.nnp = ani1x.neural_networks[0]
+        model = torchani.models.ANI1x(model_index=0)
+        self.aev_computer = model.aev_computer
+        self.nnp = model.neural_networks
         self.model = torchani.nn.Sequential(self.aev_computer, self.nnp)

     def random_skip(self):
@@ -15,8 +15,7 @@ class TestStructureOptimization(unittest.TestCase):

     def setUp(self):
         self.tolerance = 1e-6
-        self.ani1x = torchani.models.ANI1x()
-        self.calculator = self.ani1x[0].ase()
+        self.calculator = torchani.models.ANI1x(model_index=0).ase()

     def testRMSE(self):
         datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all')
@@ -26,7 +26,7 @@ directly calculate energies or get an ASE calculator. For example:
 Note that the class BuiltinModels can be accessed but it is deprecated and
 shouldn't be used anymore.
 """
+import os
 import torch
 from torch import Tensor
 from typing import Tuple, Optional
@@ -53,6 +53,45 @@ class BuiltinModel(torch.nn.Module):
         self.consts = consts
         self.sae_dict = sae_dict

+    @classmethod
+    def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False, model_index=0):
+        # this is used to load only 1 model (by default model 0)
+        consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)
+
+        if (model_index >= ensemble_size):
+            raise ValueError("The ensemble size is only {}, model {} can't be loaded".format(ensemble_size, model_index))
+
+        species_converter = SpeciesConverter(consts.species)
+        aev_computer = AEVComputer(**consts)
+        energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
+        species_to_tensor = consts.species_to_tensor
+
+        network_dir = os.path.join('{}{}'.format(ensemble_prefix, model_index), 'networks')
+        neural_networks = neurochem.load_model(consts.species, network_dir)
+
+        return cls(species_converter, aev_computer, neural_networks,
+                   energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
+
+    @staticmethod
+    def _parse_neurochem_resources(info_file_path):
+        def get_resource(file_path):
+            package_name = '.'.join(__name__.split('.')[:-1])
+            return resource_filename(package_name, 'resources/' + file_path)
+
+        info_file = get_resource(info_file_path)
+
+        with open(info_file) as f:
+            # const_file: Path to the file with the builtin constants.
+            # sae_file: Path to the file with the Self Atomic Energies.
+            # ensemble_prefix: Prefix of the neurochem resource directories.
+            lines = [x.strip() for x in f.readlines()][:4]
+            const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
+            const_file = get_resource(const_file_path)
+            sae_file = get_resource(sae_file_path)
+            ensemble_prefix = get_resource(ensemble_prefix_path)
+            ensemble_size = int(ensemble_size)
+            consts = neurochem.Constants(const_file)
+        return consts, sae_file, ensemble_prefix, ensemble_size
+
     def forward(self, species_coordinates: Tuple[Tensor, Tensor],
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None) -> SpeciesEnergies:
@@ -159,31 +198,15 @@ class BuiltinEnsemble(BuiltinModel):

     @classmethod
     def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False):
-        def get_resource(file_path):
-            package_name = '.'.join(__name__.split('.')[:-1])
-            return resource_filename(package_name, 'resources/' + file_path)
-
-        info_file = get_resource(info_file_path)
-
-        with open(info_file) as f:
-            # const_file: Path to the file with the builtin constants.
-            # sae_file: Path to the file with the Self Atomic Energies.
-            # ensemble_prefix: Prefix of the neurochem resource directories.
-            lines = [x.strip() for x in f.readlines()][:4]
-            const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
-            const_file = get_resource(const_file_path)
-            sae_file = get_resource(sae_file_path)
-            ensemble_prefix = get_resource(ensemble_prefix_path)
-            ensemble_size = int(ensemble_size)
-            consts = neurochem.Constants(const_file)
+        # this is used to load only 1 model (by default model 0)
+        consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)

         species_converter = SpeciesConverter(consts.species)
         aev_computer = AEVComputer(**consts)
-        neural_networks = neurochem.load_model_ensemble(consts.species,
-                                                        ensemble_prefix, ensemble_size)
         energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
         species_to_tensor = consts.species_to_tensor
+        neural_networks = neurochem.load_model_ensemble(consts.species,
+                                                        ensemble_prefix, ensemble_size)

         return cls(species_converter, aev_computer, neural_networks,
                    energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
@@ -220,7 +243,7 @@ class BuiltinEnsemble(BuiltinModel):
         return len(self.neural_networks)


-def ANI1x(periodic_table_index=False):
+def ANI1x(periodic_table_index=False, model_index=None):
     """The ANI-1x model as in `ani-1x_8x on GitHub`_ and `Active Learning Paper`_.

     The ANI-1x model is an ensemble of 8 networks that was trained using
@@ -234,10 +257,13 @@ def ANI1x(periodic_table_index=False):
     .. _Active Learning Paper:
         https://aip.scitation.org/doi/abs/10.1063/1.5023802
     """
-    return BuiltinEnsemble._from_neurochem_resources('ani-1x_8x.info', periodic_table_index)
+    info_file = 'ani-1x_8x.info'
+    if model_index is None:
+        return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
+    return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)


-def ANI1ccx(periodic_table_index=False):
+def ANI1ccx(periodic_table_index=False, model_index=None):
     """The ANI-1ccx model as in `ani-1ccx_8x on GitHub`_ and `Transfer Learning Paper`_.

     The ANI-1ccx model is an ensemble of 8 networks that was trained
@@ -252,4 +278,7 @@ def ANI1ccx(periodic_table_index=False):
     .. _Transfer Learning Paper:
         https://doi.org/10.26434/chemrxiv.6744440.v1
     """
-    return BuiltinEnsemble._from_neurochem_resources('ani-1ccx_8x.info', periodic_table_index)
+    info_file = 'ani-1ccx_8x.info'
+    if model_index is None:
+        return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
+    return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)
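
A minimal end-to-end sketch of how the new keyword is meant to be used from the updated ANI1x/ANI1ccx entry points. The dispatch on model_index comes straight from the diff above; the toy methane geometry and the (species, coordinates) call convention follow TorchANI's documented forward signature, so anything beyond this diff should be read as an assumption rather than a guarantee.

    import torch
    import torchani

    # model_index=None (the default) keeps the old behaviour and returns the 8-model ensemble;
    # an integer selects a single member via BuiltinModel._from_neurochem_resources.
    ensemble = torchani.models.ANI1ccx()
    single = torchani.models.ANI1ccx(periodic_table_index=True, model_index=0)

    # Toy methane input; atomic numbers are used because periodic_table_index=True.
    species = torch.tensor([[6, 1, 1, 1, 1]])
    coordinates = torch.tensor([[[ 0.00,  0.00,  0.00],
                                 [ 0.63,  0.63,  0.63],
                                 [-0.63, -0.63,  0.63],
                                 [-0.63,  0.63, -0.63],
                                 [ 0.63, -0.63, -0.63]]], requires_grad=True)

    # forward takes a (species, coordinates) tuple and returns SpeciesEnergies.
    energy = single((species, coordinates)).energies
    forces = -torch.autograd.grad(energy.sum(), coordinates)[0]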