Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
4a9944de
Unverified
Commit
4a9944de
authored
Jul 30, 2019
by
Gao, Xiang
Committed by
GitHub
Jul 30, 2019
Browse files
Cleanup code for builtin models (#266)
parent
6a510ffc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
221 deletions
+20
-221
torchani/models.py
torchani/models.py
+19
-80
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+1
-141
No files found.
torchani/models.py
View file @
4a9944de
...
...
@@ -27,69 +27,11 @@ shouldn't be used anymore.
"""
import
torch
import
warnings
from
pkg_resources
import
resource_filename
from
.
import
neurochem
from
.aev
import
AEVComputer
# Future: Delete BuiltinModels in a future release, it is DEPRECATED
class
BuiltinModels
(
torch
.
nn
.
Module
):
"""BuiltinModels class.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
"""
def
__init__
(
self
,
builtin_class
):
warnings
.
warn
(
"BuiltinsModels is deprecated and will be deleted in"
"the future; use torchani.models.BuiltinNet()"
,
DeprecationWarning
)
super
(
BuiltinModels
,
self
).
__init__
()
self
.
builtins
=
builtin_class
()
self
.
aev_computer
=
self
.
builtins
.
aev_computer
self
.
neural_networks
=
self
.
builtins
.
models
self
.
energy_shifter
=
self
.
builtins
.
energy_shifter
def
forward
(
self
,
species_coordinates
):
species_aevs
=
self
.
aev_computer
(
species_coordinates
)
species_energies
=
self
.
neural_networks
(
species_aevs
)
return
self
.
energy_shifter
(
species_energies
)
def
__getitem__
(
self
,
index
):
ret
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
self
.
neural_networks
[
index
],
self
.
energy_shifter
)
def
ase
(
**
kwargs
):
from
.
import
ase
return
ase
.
Calculator
(
self
.
builtins
.
species
,
self
.
aev_computer
,
self
.
neural_networks
[
index
],
self
.
energy_shifter
,
**
kwargs
)
ret
.
ase
=
ase
ret
.
species_to_tensor
=
self
.
builtins
.
consts
.
species_to_tensor
return
ret
def
__len__
(
self
):
return
len
(
self
.
neural_networks
)
def
ase
(
self
,
**
kwargs
):
"""Get an ASE Calculator using this model"""
from
.
import
ase
return
ase
.
Calculator
(
self
.
builtins
.
species
,
self
.
aev_computer
,
self
.
neural_networks
,
self
.
energy_shifter
,
**
kwargs
)
def
species_to_tensor
(
self
,
*
args
,
**
kwargs
):
"""Convert species from strings to tensor.
See also :method:`torchani.neurochem.Constant.species_to_tensor`"""
return
self
.
builtins
.
consts
.
species_to_tensor
(
*
args
,
**
kwargs
)
\
.
to
(
self
.
aev_computer
.
ShfR
.
device
)
class
BuiltinNet
(
torch
.
nn
.
Module
):
"""Private template for the builtin ANI ensemble models.
...
...
@@ -117,16 +59,25 @@ class BuiltinNet(torch.nn.Module):
neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks
"""
def
__init__
(
self
,
parent_name
,
const_file_path
,
sae_file_path
,
ensemble_size
,
ensemble_prefix_path
):
def
__init__
(
self
,
info_file
):
super
(
BuiltinNet
,
self
).
__init__
()
self
.
const_file
=
resource_filename
(
parent_name
,
const_file_path
)
self
.
sae_file
=
resource_filename
(
parent_name
,
sae_file_path
)
self
.
ensemble_prefix
=
resource_filename
(
parent_name
,
ensemble_prefix_path
)
package_name
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
])
info_file
=
'resources/'
+
info_file
self
.
info_file
=
resource_filename
(
package_name
,
info_file
)
with
open
(
self
.
info_file
)
as
f
:
lines
=
[
x
.
strip
()
for
x
in
f
.
readlines
()][:
4
]
const_file_path
,
sae_file_path
,
ensemble_prefix_path
,
ensemble_size
=
lines
const_file_path
=
'resources/'
+
const_file_path
sae_file_path
=
'resources/'
+
sae_file_path
ensemble_prefix_path
=
'resources/'
+
ensemble_prefix_path
ensemble_size
=
int
(
ensemble_size
)
self
.
const_file
=
resource_filename
(
package_name
,
const_file_path
)
self
.
sae_file
=
resource_filename
(
package_name
,
sae_file_path
)
self
.
ensemble_prefix
=
resource_filename
(
package_name
,
ensemble_prefix_path
)
self
.
ensemble_size
=
ensemble_size
self
.
consts
=
neurochem
.
Constants
(
self
.
const_file
)
self
.
species
=
self
.
consts
.
species
self
.
aev_computer
=
AEVComputer
(
**
self
.
consts
)
...
...
@@ -234,13 +185,7 @@ class ANI1x(BuiltinNet):
"""
def
__init__
(
self
):
super
(
ANI1x
,
self
).
__init__
(
parent_name
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
]),
const_file_path
=
'resources/ani-1x_8x'
'/rHCNO-5.2R_16-3.5A_a4-8.params'
,
sae_file_path
=
'resources/ani-1x_8x/sae_linfit.dat'
,
ensemble_size
=
8
,
ensemble_prefix_path
=
'resources/ani-1x_8x/train'
)
super
(
ANI1x
,
self
).
__init__
(
'ani-1x_8x.info'
)
class
ANI1ccx
(
BuiltinNet
):
...
...
@@ -260,10 +205,4 @@ class ANI1ccx(BuiltinNet):
"""
def
__init__
(
self
):
super
(
ANI1ccx
,
self
).
__init__
(
parent_name
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
]),
const_file_path
=
'resources/ani-1ccx_8x'
'/rHCNO-5.2R_16-3.5A_a4-8.params'
,
sae_file_path
=
'resources/ani-1ccx_8x/sae_linfit.dat'
,
ensemble_size
=
8
,
ensemble_prefix_path
=
'resources/ani-1ccx_8x/train'
)
super
(
ANI1ccx
,
self
).
__init__
(
'ani-1ccx_8x.info'
)
torchani/neurochem/__init__.py
View file @
4a9944de
# -*- coding: utf-8 -*-
"""Tools for loading/running NeuroChem input files."""
import
pkg_resources
import
torch
import
os
import
bz2
...
...
@@ -262,144 +261,6 @@ def load_model_ensemble(species, prefix, count):
return
Ensemble
(
models
)
# Future: Delete BuiltinsAbstract in a future release, it is DEPRECATED
class
BuiltinsAbstract
(
object
):
"""Base class for loading ANI neural network from configuration files.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Arguments:
parent_name (:class:`str`): Base path that other paths are relative to.
const_file_path (:class:`str`): Path to constant file for ANI model(s).
sae_file_path (:class:`str`): Path to sae file for ANI model(s).
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix_path (:class:`str`): Path to prefix of directories of
models.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def
__init__
(
self
,
parent_name
,
const_file_path
,
sae_file_path
,
ensemble_size
,
ensemble_prefix_path
):
self
.
const_file
=
pkg_resources
.
resource_filename
(
parent_name
,
const_file_path
)
warnings
.
warn
(
"BuiltinsAbstract is deprecated and will be deleted in"
"the future; use torchani.models.BuiltinNet()"
,
DeprecationWarning
)
self
.
consts
=
Constants
(
self
.
const_file
)
self
.
species
=
self
.
consts
.
species
self
.
aev_computer
=
AEVComputer
(
**
self
.
consts
)
self
.
sae_file
=
pkg_resources
.
resource_filename
(
parent_name
,
sae_file_path
)
self
.
energy_shifter
=
load_sae
(
self
.
sae_file
)
self
.
ensemble_size
=
ensemble_size
self
.
ensemble_prefix
=
pkg_resources
.
resource_filename
(
parent_name
,
ensemble_prefix_path
)
self
.
models
=
load_model_ensemble
(
self
.
consts
.
species
,
self
.
ensemble_prefix
,
self
.
ensemble_size
)
# Future: Delete Builtins in a future release, it is DEPRECATED
class
Builtins
(
BuiltinsAbstract
):
"""Container for the builtin ANI-1x model.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def
__init__
(
self
):
warnings
.
warn
(
"Builtins is deprecated and will be deleted in the"
"future; use torchani.models.ANI1x()"
,
DeprecationWarning
)
parent_name
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
])
const_file_path
=
'resources/ani-1x_8x'
\
'/rHCNO-5.2R_16-3.5A_a4-8.params'
sae_file_path
=
'resources/ani-1x_8x/sae_linfit.dat'
ensemble_size
=
8
ensemble_prefix_path
=
'resources/ani-1x_8x/train'
super
(
Builtins
,
self
).
__init__
(
parent_name
,
const_file_path
,
sae_file_path
,
ensemble_size
,
ensemble_prefix_path
)
# Future: Delete BuiltinsANI1CCX in a future release, it is DEPRECATED
class
BuiltinsANI1CCX
(
BuiltinsAbstract
):
"""Container for the builtin ANI-1ccx model.
.. warning::
This class is part of an old API. It is DEPRECATED and may be deleted in a
future version. It shouldn't be used.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def
__init__
(
self
):
warnings
.
warn
(
"BuiltinsANICCX is deprecated and will be deleted in the"
"future; use torchani.models.ANI1ccx()"
,
DeprecationWarning
)
parent_name
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
])
const_file_path
=
'resources/ani-1ccx_8x'
\
'/rHCNO-5.2R_16-3.5A_a4-8.params'
sae_file_path
=
'resources/ani-1ccx_8x/sae_linfit.dat'
ensemble_size
=
8
ensemble_prefix_path
=
'resources/ani-1ccx_8x/train'
super
(
BuiltinsANI1CCX
,
self
).
__init__
(
parent_name
,
const_file_path
,
sae_file_path
,
ensemble_size
,
ensemble_prefix_path
)
def
hartree2kcal
(
x
):
return
627.509
*
x
...
...
@@ -861,5 +722,4 @@ if sys.version_info[0] > 2:
lr
*=
self
.
lr_decay
__all__
=
[
'Constants'
,
'load_sae'
,
'load_model'
,
'load_model_ensemble'
,
'Builtins'
,
'Trainer'
]
__all__
=
[
'Constants'
,
'load_sae'
,
'load_model'
,
'load_model_ensemble'
,
'Trainer'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment