Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
f9db30c3
Unverified
Commit
f9db30c3
authored
Oct 27, 2018
by
Gao, Xiang
Committed by
GitHub
Oct 27, 2018
Browse files
ASE calculator (#121)
parent
84fc8d80
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
12 deletions
+83
-12
docs/api.rst
docs/api.rst
+4
-0
torchani/aev.py
torchani/aev.py
+5
-1
torchani/ase.py
torchani/ase.py
+47
-0
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+6
-10
torchani/utils.py
torchani/utils.py
+21
-1
No files found.
docs/api.rst
View file @
f9db30c3
...
...
@@ -27,6 +27,8 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
NeuroChem
...
...
@@ -51,6 +53,8 @@ ASE Interface
.. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator
:members:
Ignite Helpers
==============
...
...
torchani/aev.py
View file @
f9db30c3
...
...
@@ -65,7 +65,11 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
...
...
torchani/ase.py
View file @
f9db30c3
...
...
@@ -9,6 +9,8 @@ import math
import
torch
import
ase.neighborlist
from
.
import
utils
import
ase.calculators.calculator
import
ase.units
class
NeighborList
:
...
...
@@ -80,3 +82,48 @@ class NeighborList:
return
neighbor_species
.
permute
(
0
,
2
,
1
),
\
neighbor_distances
.
permute
(
0
,
2
,
1
),
\
neighbor_vecs
.
permute
(
0
,
2
,
1
,
3
)
class
Calculator
(
ase
.
calculators
.
calculator
.
Calculator
):
"""TorchANI calculator for ASE
Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer.
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
"""
def
__init__
(
self
,
species
,
aev_computer
,
model
,
energy_shifter
):
self
.
species_to_tensor
=
utils
.
ChemicalSymbolsToInts
(
species
)
self
.
aev_computer
=
aev_computer
self
.
model
=
model
self
.
energy_shifter
=
energy_shifter
self
.
device
=
self
.
aev_computer
.
EtaR
.
device
self
.
dtype
=
self
.
aev_computer
.
EtaR
.
dtype
self
.
whole
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
self
.
model
,
self
.
energy_shifter
)
def
calculate
(
self
,
atoms
=
None
,
properties
=
[
'energy'
],
system_changes
=
ase
.
calculators
.
calculator
.
all_changes
):
super
(
Calculator
,
self
).
calculate
(
atoms
,
properties
,
system_changes
)
self
.
aev_computer
.
neighbor_list
=
NeighborList
(
cell
=
self
.
atoms
.
get_cell
(),
pbc
=
self
.
atoms
.
get_pbc
())
species
=
self
.
species_to_tensor
(
self
.
atoms
.
get_chemical_symbols
())
coordinates
=
self
.
atoms
.
get_positions
(
wrap
=
True
).
unsqueeze
(
0
)
coordinates
=
torch
.
tensor
(
coordinates
,
device
=
self
.
device
,
dtype
=
self
.
dtype
,
requires_grad
=
(
'forces'
in
properties
))
_
,
energy
=
self
.
whole
((
species
,
coordinates
))
*
ase
.
units
.
Hartree
self
.
results
[
'energy'
]
=
energy
.
item
()
if
'forces'
in
properties
:
forces
=
-
torch
.
autograd
.
grad
(
energy
.
squeeze
(),
coordinates
)[
0
]
self
.
results
[
'forces'
]
=
forces
.
item
()
torchani/neurochem/__init__.py
View file @
f9db30c3
...
...
@@ -13,7 +13,7 @@ import math
import
timeit
from
collections.abc
import
Mapping
from
..nn
import
ANIModel
,
Ensemble
,
Gaussian
from
..utils
import
EnergyShifter
from
..utils
import
EnergyShifter
,
ChemicalSymbolsToInts
from
..aev
import
AEVComputer
from
..ignite
import
Container
,
MSELoss
,
TransformedLoss
,
RMSEMetric
,
MAEMetric
...
...
@@ -21,6 +21,10 @@ from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
class
Constants
(
Mapping
):
"""NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
Attributes:
species_to_tensor (:class:`ChemicalSymbolsToInts`): call to convert
string chemical symbols to 1d long tensor.
"""
def
__init__
(
self
,
filename
):
...
...
@@ -45,10 +49,7 @@ class Constants(Mapping):
except
Exception
:
raise
ValueError
(
'unable to parse const file'
)
self
.
num_species
=
len
(
self
.
species
)
self
.
rev_species
=
{}
for
i
in
range
(
len
(
self
.
species
)):
s
=
self
.
species
[
i
]
self
.
rev_species
[
s
]
=
i
self
.
species_to_tensor
=
ChemicalSymbolsToInts
(
self
.
species
)
def
__iter__
(
self
):
yield
'Rcr'
...
...
@@ -67,11 +68,6 @@ class Constants(Mapping):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
def
species_to_tensor
(
self
,
species
):
"""Convert species from squence of strings to 1D tensor"""
rev
=
[
self
.
rev_species
[
s
]
for
s
in
species
]
return
torch
.
tensor
(
rev
,
dtype
=
torch
.
long
)
def
load_sae
(
filename
):
"""Returns an object of :class:`EnergyShifter` with self energies from
...
...
torchani/utils.py
View file @
f9db30c3
...
...
@@ -151,5 +151,25 @@ class EnergyShifter(torch.nn.Module):
return
species
,
energies
+
sae
class
ChemicalSymbolsToInts
:
"""Helper that can be called to convert chemical symbol string to integers
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
"""
def
__init__
(
self
,
all_species
):
self
.
rev_species
=
{}
for
i
in
range
(
len
(
all_species
)):
s
=
all_species
[
i
]
self
.
rev_species
[
s
]
=
i
def
__call__
(
self
,
species
):
"""Convert species from squence of strings to 1D tensor"""
rev
=
[
self
.
rev_species
[
s
]
for
s
in
species
]
return
torch
.
tensor
(
rev
,
dtype
=
torch
.
long
)
__all__
=
[
'pad'
,
'pad_coordinates'
,
'present_species'
,
'strip_redundant_padding'
]
'strip_redundant_padding'
,
'ChemicalSymbolsToInts'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment