Unverified Commit 4a9944de authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Cleanup code for builtin models (#266)

parent 6a510ffc
......@@ -27,69 +27,11 @@ shouldn't be used anymore.
"""
import torch
import warnings
from pkg_resources import resource_filename
from . import neurochem
from .aev import AEVComputer
# Future: Delete BuiltinModels in a future release, it is DEPRECATED
class BuiltinModels(torch.nn.Module):
"""BuiltinModels class.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
"""
def __init__(self, builtin_class):
warnings.warn(
"BuiltinsModels is deprecated and will be deleted in"
"the future; use torchani.models.BuiltinNet()", DeprecationWarning)
super(BuiltinModels, self).__init__()
self.builtins = builtin_class()
self.aev_computer = self.builtins.aev_computer
self.neural_networks = self.builtins.models
self.energy_shifter = self.builtins.energy_shifter
def forward(self, species_coordinates):
species_aevs = self.aev_computer(species_coordinates)
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)
def __getitem__(self, index):
ret = torch.nn.Sequential(self.aev_computer,
self.neural_networks[index],
self.energy_shifter)
def ase(**kwargs):
from . import ase
return ase.Calculator(self.builtins.species, self.aev_computer,
self.neural_networks[index],
self.energy_shifter, **kwargs)
ret.ase = ase
ret.species_to_tensor = self.builtins.consts.species_to_tensor
return ret
def __len__(self):
return len(self.neural_networks)
def ase(self, **kwargs):
"""Get an ASE Calculator using this model"""
from . import ase
return ase.Calculator(self.builtins.species, self.aev_computer,
self.neural_networks, self.energy_shifter,
**kwargs)
def species_to_tensor(self, *args, **kwargs):
"""Convert species from strings to tensor.
See also :method:`torchani.neurochem.Constant.species_to_tensor`"""
return self.builtins.consts.species_to_tensor(*args, **kwargs) \
.to(self.aev_computer.ShfR.device)
class BuiltinNet(torch.nn.Module):
"""Private template for the builtin ANI ensemble models.
......@@ -117,16 +59,25 @@ class BuiltinNet(torch.nn.Module):
neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks
"""
def __init__(self, parent_name, const_file_path, sae_file_path,
ensemble_size, ensemble_prefix_path):
def __init__(self, info_file):
super(BuiltinNet, self).__init__()
self.const_file = resource_filename(parent_name, const_file_path)
self.sae_file = resource_filename(parent_name, sae_file_path)
self.ensemble_prefix = resource_filename(parent_name,
ensemble_prefix_path)
package_name = '.'.join(__name__.split('.')[:-1])
info_file = 'resources/' + info_file
self.info_file = resource_filename(package_name, info_file)
with open(self.info_file) as f:
lines = [x.strip() for x in f.readlines()][:4]
const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
const_file_path = 'resources/' + const_file_path
sae_file_path = 'resources/' + sae_file_path
ensemble_prefix_path = 'resources/' + ensemble_prefix_path
ensemble_size = int(ensemble_size)
self.const_file = resource_filename(package_name, const_file_path)
self.sae_file = resource_filename(package_name, sae_file_path)
self.ensemble_prefix = resource_filename(package_name, ensemble_prefix_path)
self.ensemble_size = ensemble_size
self.consts = neurochem.Constants(self.const_file)
self.species = self.consts.species
self.aev_computer = AEVComputer(**self.consts)
......@@ -234,13 +185,7 @@ class ANI1x(BuiltinNet):
"""
def __init__(self):
super(ANI1x, self).__init__(
parent_name='.'.join(__name__.split('.')[:-1]),
const_file_path='resources/ani-1x_8x'
'/rHCNO-5.2R_16-3.5A_a4-8.params',
sae_file_path='resources/ani-1x_8x/sae_linfit.dat',
ensemble_size=8,
ensemble_prefix_path='resources/ani-1x_8x/train')
super(ANI1x, self).__init__('ani-1x_8x.info')
class ANI1ccx(BuiltinNet):
......@@ -260,10 +205,4 @@ class ANI1ccx(BuiltinNet):
"""
def __init__(self):
super(ANI1ccx, self).__init__(
parent_name='.'.join(__name__.split('.')[:-1]),
const_file_path='resources/ani-1ccx_8x'
'/rHCNO-5.2R_16-3.5A_a4-8.params',
sae_file_path='resources/ani-1ccx_8x/sae_linfit.dat',
ensemble_size=8,
ensemble_prefix_path='resources/ani-1ccx_8x/train')
super(ANI1ccx, self).__init__('ani-1ccx_8x.info')
# -*- coding: utf-8 -*-
"""Tools for loading/running NeuroChem input files."""
import pkg_resources
import torch
import os
import bz2
......@@ -262,144 +261,6 @@ def load_model_ensemble(species, prefix, count):
return Ensemble(models)
# Future: Delete BuiltinsAbstract in a future release, it is DEPRECATED
class BuiltinsAbstract(object):
"""Base class for loading ANI neural network from configuration files.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Arguments:
parent_name (:class:`str`): Base path that other paths are relative to.
const_file_path (:class:`str`): Path to constant file for ANI model(s).
sae_file_path (:class:`str`): Path to sae file for ANI model(s).
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix_path (:class:`str`): Path to prefix of directories of
models.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def __init__(
self,
parent_name,
const_file_path,
sae_file_path,
ensemble_size,
ensemble_prefix_path):
self.const_file = pkg_resources.resource_filename(
parent_name,
const_file_path)
warnings.warn(
"BuiltinsAbstract is deprecated and will be deleted in"
"the future; use torchani.models.BuiltinNet()", DeprecationWarning)
self.consts = Constants(self.const_file)
self.species = self.consts.species
self.aev_computer = AEVComputer(**self.consts)
self.sae_file = pkg_resources.resource_filename(
parent_name,
sae_file_path)
self.energy_shifter = load_sae(self.sae_file)
self.ensemble_size = ensemble_size
self.ensemble_prefix = pkg_resources.resource_filename(
parent_name,
ensemble_prefix_path)
self.models = load_model_ensemble(self.consts.species,
self.ensemble_prefix,
self.ensemble_size)
# Future: Delete Builtins in a future release, it is DEPRECATED
class Builtins(BuiltinsAbstract):
"""Container for the builtin ANI-1x model.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def __init__(self):
warnings.warn(
"Builtins is deprecated and will be deleted in the"
"future; use torchani.models.ANI1x()", DeprecationWarning)
parent_name = '.'.join(__name__.split('.')[:-1])
const_file_path = 'resources/ani-1x_8x'\
'/rHCNO-5.2R_16-3.5A_a4-8.params'
sae_file_path = 'resources/ani-1x_8x/sae_linfit.dat'
ensemble_size = 8
ensemble_prefix_path = 'resources/ani-1x_8x/train'
super(Builtins, self).__init__(
parent_name,
const_file_path,
sae_file_path,
ensemble_size,
ensemble_prefix_path
)
# Future: Delete BuiltinsANI1CCX in a future release, it is DEPRECATED
class BuiltinsANI1CCX(BuiltinsAbstract):
"""Container for the builtin ANI-1ccx model.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def __init__(self):
warnings.warn(
"BuiltinsANICCX is deprecated and will be deleted in the"
"future; use torchani.models.ANI1ccx()", DeprecationWarning)
parent_name = '.'.join(__name__.split('.')[:-1])
const_file_path = 'resources/ani-1ccx_8x'\
'/rHCNO-5.2R_16-3.5A_a4-8.params'
sae_file_path = 'resources/ani-1ccx_8x/sae_linfit.dat'
ensemble_size = 8
ensemble_prefix_path = 'resources/ani-1ccx_8x/train'
super(BuiltinsANI1CCX, self).__init__(
parent_name,
const_file_path,
sae_file_path,
ensemble_size,
ensemble_prefix_path
)
def hartree2kcal(x):
return 627.509 * x
......@@ -861,5 +722,4 @@ if sys.version_info[0] > 2:
lr *= self.lr_decay
__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble',
'Builtins', 'Trainer']
__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble', 'Trainer']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment