Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
004f5a52
Commit
004f5a52
authored
Nov 08, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 08, 2019
Browse files
Use namedtuple to improve API while still maintaining backward compatibility (#380)
* Use namedtuple to improve API * improve
parent
92c307dc
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
46 additions
and
31 deletions
+46
-31
examples/ase_interface.py
examples/ase_interface.py
+0
-1
examples/energy_force.py
examples/energy_force.py
+1
-2
examples/jit.py
examples/jit.py
+4
-4
examples/load_from_neurochem.py
examples/load_from_neurochem.py
+2
-2
examples/nnp_training.py
examples/nnp_training.py
+2
-2
examples/nnp_training_force.py
examples/nnp_training_force.py
+2
-2
examples/vibration_analysis.py
examples/vibration_analysis.py
+1
-1
torchani/aev.py
torchani/aev.py
+10
-4
torchani/ase.py
torchani/ase.py
+3
-3
torchani/nn.py
torchani/nn.py
+11
-6
torchani/utils.py
torchani/utils.py
+10
-4
No files found.
examples/ase_interface.py
View file @
004f5a52
...
...
@@ -16,7 +16,6 @@ calculator.
###############################################################################
# To begin with, let's first import the modules we will use:
from
__future__
import
print_function
from
ase.lattice.cubic
import
Diamond
from
ase.md.langevin
import
Langevin
from
ase.optimize
import
BFGS
...
...
examples/energy_force.py
View file @
004f5a52
...
...
@@ -9,7 +9,6 @@ TorchANI and can be used directly.
###############################################################################
# To begin with, let's first import the modules we will use:
from
__future__
import
print_function
import
torch
import
torchani
...
...
@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
###############################################################################
# Now let's compute energy and force:
_
,
energy
=
model
((
species
,
coordinates
))
energy
=
model
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
...
...
examples/jit.py
View file @
004f5a52
...
...
@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0)
###############################################################################
# And here is the result:
_
,
energies_ensemble
=
model
((
species
,
coordinates
))
_
,
energies_single
=
model
[
0
]((
species
,
coordinates
))
_
,
energies_ensemble_jit
=
loaded_compiled_model
((
species
,
coordinates
))
_
,
energies_single_jit
=
loaded_compiled_model0
((
species
,
coordinates
))
energies_ensemble
=
model
((
species
,
coordinates
))
.
energies
energies_single
=
model
[
0
]((
species
,
coordinates
))
.
energies
energies_ensemble_jit
=
loaded_compiled_model
((
species
,
coordinates
))
.
energies
energies_single_jit
=
loaded_compiled_model0
((
species
,
coordinates
))
.
energies
print
(
'Ensemble energy, eager mode vs loaded jit:'
,
energies_ensemble
.
item
(),
energies_ensemble_jit
.
item
())
print
(
'Single network energy, eager mode vs loaded jit:'
,
energies_single
.
item
(),
energies_single_jit
.
item
())
examples/load_from_neurochem.py
View file @
004f5a52
...
...
@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy())
###############################################################################
# Now let's compute energies using the ensemble directly:
_
,
energy
=
nnp1
((
species
,
coordinates
))
energy
=
nnp1
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
print
(
'Energy:'
,
energy
.
item
())
...
...
@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree)
###############################################################################
# We can do the same thing with the single model:
_
,
energy
=
nnp2
((
species
,
coordinates
))
energy
=
nnp2
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
print
(
'Energy:'
,
energy
.
item
())
...
...
examples/nnp_training.py
View file @
004f5a52
...
...
@@ -286,7 +286,7 @@ def validate():
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
...
...
@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
for
chunk_species
,
chunk_coordinates
in
batch_x
:
num_atoms
.
append
((
chunk_species
>=
0
).
to
(
true_energies
.
dtype
).
sum
(
dim
=
1
))
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
num_atoms
=
torch
.
cat
(
num_atoms
)
...
...
examples/nnp_training_force.py
View file @
004f5a52
...
...
@@ -231,7 +231,7 @@ def validate():
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
...
...
@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
# that we could compute force from it
chunk_coordinates
.
requires_grad_
(
True
)
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
# We can use torch.autograd.grad to compute force. Remember to
# create graph so that the loss of the force can contribute to
...
...
examples/vibration_analysis.py
View file @
004f5a52
...
...
@@ -54,7 +54,7 @@ masses = element_masses[species]
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
_
,
energies
=
model
((
species
,
coordinates
))
energies
=
model
((
species
,
coordinates
))
.
energies
###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
...
...
torchani/aev.py
View file @
004f5a52
import
torch
from
torch
import
Tensor
import
math
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
,
NamedTuple
from
torch.jit
import
Final
class
SpeciesAEV
(
NamedTuple
):
species
:
Tensor
aevs
:
Tensor
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
# assuming all elements in distances are smaller than cutoff
return
0.5
*
torch
.
cos
(
distances
*
(
math
.
pi
/
cutoff
))
+
0.5
...
...
@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module):
return
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
pbc
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Tensor
]
:
pbc
:
Optional
[
Tensor
]
=
None
)
->
SpeciesAEV
:
"""Compute AEVs
Arguments:
...
...
@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module):
for that direction.
Returns:
t
uple: Species and AEVs. species are the species from the input
NamedT
uple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
...
...
@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module):
cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
shifts
=
compute_shifts
(
cell
,
pbc
,
cutoff
)
return
species
,
compute_aev
(
species
,
coordinates
,
cell
,
shifts
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
)
aev
=
compute_aev
(
species
,
coordinates
,
cell
,
shifts
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
)
return
SpeciesAEV
(
species
,
aev
)
torchani/ase.py
View file @
004f5a52
...
...
@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y
=
self
.
strain
(
cell
,
displacement_y
,
1
)
strain_z
=
self
.
strain
(
cell
,
displacement_z
,
2
)
cell
=
cell
+
strain_x
+
strain_y
+
strain_z
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
),
cell
=
cell
,
pbc
=
pbc
)
aev
=
self
.
aev_computer
((
species
,
coordinates
),
cell
=
cell
,
pbc
=
pbc
)
.
aevs
else
:
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
))
aev
=
self
.
aev_computer
((
species
,
coordinates
))
.
aevs
_
,
energy
=
self
.
nn
((
species
,
aev
))
energy
=
self
.
nn
((
species
,
aev
))
.
energies
energy
*=
ase
.
units
.
Hartree
self
.
results
[
'energy'
]
=
energy
.
item
()
self
.
results
[
'free_energy'
]
=
energy
.
item
()
...
...
torchani/nn.py
View file @
004f5a52
import
torch
from
torch
import
Tensor
from
typing
import
Tuple
from
typing
import
Tuple
,
NamedTuple
class
SpeciesEnergies
(
NamedTuple
):
species
:
Tensor
energies
:
Tensor
class
ANIModel
(
torch
.
nn
.
Module
):
...
...
@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module):
def
__getitem__
(
self
,
i
):
return
self
.
module_list
[
i
]
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
species
,
aev
=
species_aev
species_
=
species
.
flatten
()
aev
=
aev
.
flatten
(
0
,
1
)
...
...
@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module):
input_
=
aev
.
index_select
(
0
,
midx
)
output
.
masked_scatter_
(
mask
,
m
(
input_
).
flatten
())
output
=
output
.
view_as
(
species
)
return
species
,
torch
.
sum
(
output
,
dim
=
1
)
return
SpeciesEnergies
(
species
,
torch
.
sum
(
output
,
dim
=
1
)
)
class
Ensemble
(
torch
.
nn
.
Module
):
...
...
@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module):
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
self
.
size
=
len
(
self
.
modules_list
)
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
sum_
=
0
for
x
in
self
.
modules_list
:
sum_
+=
x
(
species_input
)[
1
]
species
,
_
=
species_input
return
species
,
sum_
/
self
.
size
return
SpeciesEnergies
(
species
,
sum_
/
self
.
size
)
def
__getitem__
(
self
,
i
):
return
self
.
modules_list
[
i
]
...
...
@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module):
super
(
Sequential
,
self
).
__init__
()
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
]):
for
module
in
self
.
modules_list
:
input_
=
module
(
input_
)
return
input_
...
...
torchani/utils.py
View file @
004f5a52
...
...
@@ -4,7 +4,8 @@ import torch.utils.data
import
math
import
numpy
as
np
from
collections
import
defaultdict
from
typing
import
Tuple
from
typing
import
Tuple
,
NamedTuple
from
.nn
import
SpeciesEnergies
def
pad
(
species
):
...
...
@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module):
properties
[
'energies'
]
=
energies
return
atomic_properties
,
properties
def
forward
(
self
,
species_energies
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_energies
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
species
,
energies
=
species_energies
sae
=
self
.
sae
(
species
).
to
(
energies
.
device
)
return
species
,
energies
.
to
(
sae
.
dtype
)
+
sae
return
SpeciesEnergies
(
species
,
energies
.
to
(
sae
.
dtype
)
+
sae
)
class
ChemicalSymbolsToInts
:
...
...
@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
],
dim
=
1
)
class
FreqsModes
(
NamedTuple
):
freqs
:
Tensor
modes
:
Tensor
def
vibrational_analysis
(
masses
,
hessian
,
unit
=
'cm^-1'
):
"""Computing the vibrational wavenumbers from hessian."""
if
unit
!=
'cm^-1'
:
...
...
@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers
=
frequencies
*
17092
modes
=
(
eigenvectors
.
t
()
*
inv_sqrt_mass
).
reshape
(
frequencies
.
numel
(),
-
1
,
3
)
return
wavenumbers
,
modes
return
FreqsModes
(
wavenumbers
,
modes
)
__all__
=
[
'pad'
,
'pad_atomic_properties'
,
'present_species'
,
'hessian'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment