Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
004f5a52
Commit
004f5a52
authored
Nov 08, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 08, 2019
Browse files
Use namedtuple to improve API while still maintaining backward compatibility (#380)
* Use namedtuple to improve API * improve
parent
92c307dc
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
46 additions
and
31 deletions
+46
-31
examples/ase_interface.py
examples/ase_interface.py
+0
-1
examples/energy_force.py
examples/energy_force.py
+1
-2
examples/jit.py
examples/jit.py
+4
-4
examples/load_from_neurochem.py
examples/load_from_neurochem.py
+2
-2
examples/nnp_training.py
examples/nnp_training.py
+2
-2
examples/nnp_training_force.py
examples/nnp_training_force.py
+2
-2
examples/vibration_analysis.py
examples/vibration_analysis.py
+1
-1
torchani/aev.py
torchani/aev.py
+10
-4
torchani/ase.py
torchani/ase.py
+3
-3
torchani/nn.py
torchani/nn.py
+11
-6
torchani/utils.py
torchani/utils.py
+10
-4
No files found.
examples/ase_interface.py
View file @
004f5a52
...
@@ -16,7 +16,6 @@ calculator.
...
@@ -16,7 +16,6 @@ calculator.
###############################################################################
###############################################################################
# To begin with, let's first import the modules we will use:
# To begin with, let's first import the modules we will use:
from
__future__
import
print_function
from
ase.lattice.cubic
import
Diamond
from
ase.lattice.cubic
import
Diamond
from
ase.md.langevin
import
Langevin
from
ase.md.langevin
import
Langevin
from
ase.optimize
import
BFGS
from
ase.optimize
import
BFGS
...
...
examples/energy_force.py
View file @
004f5a52
...
@@ -9,7 +9,6 @@ TorchANI and can be used directly.
...
@@ -9,7 +9,6 @@ TorchANI and can be used directly.
###############################################################################
###############################################################################
# To begin with, let's first import the modules we will use:
# To begin with, let's first import the modules we will use:
from
__future__
import
print_function
import
torch
import
torch
import
torchani
import
torchani
...
@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
...
@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
###############################################################################
###############################################################################
# Now let's compute energy and force:
# Now let's compute energy and force:
_
,
energy
=
model
((
species
,
coordinates
))
energy
=
model
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
force
=
-
derivative
...
...
examples/jit.py
View file @
004f5a52
...
@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0)
...
@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0)
###############################################################################
###############################################################################
# And here is the result:
# And here is the result:
_
,
energies_ensemble
=
model
((
species
,
coordinates
))
energies_ensemble
=
model
((
species
,
coordinates
))
.
energies
_
,
energies_single
=
model
[
0
]((
species
,
coordinates
))
energies_single
=
model
[
0
]((
species
,
coordinates
))
.
energies
_
,
energies_ensemble_jit
=
loaded_compiled_model
((
species
,
coordinates
))
energies_ensemble_jit
=
loaded_compiled_model
((
species
,
coordinates
))
.
energies
_
,
energies_single_jit
=
loaded_compiled_model0
((
species
,
coordinates
))
energies_single_jit
=
loaded_compiled_model0
((
species
,
coordinates
))
.
energies
print
(
'Ensemble energy, eager mode vs loaded jit:'
,
energies_ensemble
.
item
(),
energies_ensemble_jit
.
item
())
print
(
'Ensemble energy, eager mode vs loaded jit:'
,
energies_ensemble
.
item
(),
energies_ensemble_jit
.
item
())
print
(
'Single network energy, eager mode vs loaded jit:'
,
energies_single
.
item
(),
energies_single_jit
.
item
())
print
(
'Single network energy, eager mode vs loaded jit:'
,
energies_single
.
item
(),
energies_single_jit
.
item
())
examples/load_from_neurochem.py
View file @
004f5a52
...
@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy())
...
@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy())
###############################################################################
###############################################################################
# Now let's compute energies using the ensemble directly:
# Now let's compute energies using the ensemble directly:
_
,
energy
=
nnp1
((
species
,
coordinates
))
energy
=
nnp1
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
force
=
-
derivative
print
(
'Energy:'
,
energy
.
item
())
print
(
'Energy:'
,
energy
.
item
())
...
@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree)
...
@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree)
###############################################################################
###############################################################################
# We can do the same thing with the single model:
# We can do the same thing with the single model:
_
,
energy
=
nnp2
((
species
,
coordinates
))
energy
=
nnp2
((
species
,
coordinates
))
.
energies
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
force
=
-
derivative
force
=
-
derivative
print
(
'Energy:'
,
energy
.
item
())
print
(
'Energy:'
,
energy
.
item
())
...
...
examples/nnp_training.py
View file @
004f5a52
...
@@ -286,7 +286,7 @@ def validate():
...
@@ -286,7 +286,7 @@ def validate():
true_energies
=
batch_y
[
'energies'
]
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
...
@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
...
@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
for
chunk_species
,
chunk_coordinates
in
batch_x
:
for
chunk_species
,
chunk_coordinates
in
batch_x
:
num_atoms
.
append
((
chunk_species
>=
0
).
to
(
true_energies
.
dtype
).
sum
(
dim
=
1
))
num_atoms
.
append
((
chunk_species
>=
0
).
to
(
true_energies
.
dtype
).
sum
(
dim
=
1
))
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
.
append
(
chunk_energies
)
num_atoms
=
torch
.
cat
(
num_atoms
)
num_atoms
=
torch
.
cat
(
num_atoms
)
...
...
examples/nnp_training_force.py
View file @
004f5a52
...
@@ -231,7 +231,7 @@ def validate():
...
@@ -231,7 +231,7 @@ def validate():
true_energies
=
batch_y
[
'energies'
]
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
...
@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
...
@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
# that we could compute force from it
# that we could compute force from it
chunk_coordinates
.
requires_grad_
(
True
)
chunk_coordinates
.
requires_grad_
(
True
)
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
.
energies
# We can use torch.autograd.grad to compute force. Remember to
# We can use torch.autograd.grad to compute force. Remember to
# create graph so that the loss of the force can contribute to
# create graph so that the loss of the force can contribute to
...
...
examples/vibration_analysis.py
View file @
004f5a52
...
@@ -54,7 +54,7 @@ masses = element_masses[species]
...
@@ -54,7 +54,7 @@ masses = element_masses[species]
# To do vibration analysis, we first need to generate a graph that computes
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
# is the same as the code to compute energy:
_
,
energies
=
model
((
species
,
coordinates
))
energies
=
model
((
species
,
coordinates
))
.
energies
###############################################################################
###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
# We can now use the energy graph to compute analytical Hessian matrix:
...
...
torchani/aev.py
View file @
004f5a52
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
math
import
math
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
,
NamedTuple
from
torch.jit
import
Final
from
torch.jit
import
Final
class
SpeciesAEV
(
NamedTuple
):
species
:
Tensor
aevs
:
Tensor
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
# assuming all elements in distances are smaller than cutoff
# assuming all elements in distances are smaller than cutoff
return
0.5
*
torch
.
cos
(
distances
*
(
math
.
pi
/
cutoff
))
+
0.5
return
0.5
*
torch
.
cos
(
distances
*
(
math
.
pi
/
cutoff
))
+
0.5
...
@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module):
return
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
return
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
pbc
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Tensor
]
:
pbc
:
Optional
[
Tensor
]
=
None
)
->
SpeciesAEV
:
"""Compute AEVs
"""Compute AEVs
Arguments:
Arguments:
...
@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module):
for that direction.
for that direction.
Returns:
Returns:
t
uple: Species and AEVs. species are the species from the input
NamedT
uple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
``(C, A, self.aev_length())``
"""
"""
...
@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module):
...
@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module):
cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
shifts
=
compute_shifts
(
cell
,
pbc
,
cutoff
)
shifts
=
compute_shifts
(
cell
,
pbc
,
cutoff
)
return
species
,
compute_aev
(
species
,
coordinates
,
cell
,
shifts
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
)
aev
=
compute_aev
(
species
,
coordinates
,
cell
,
shifts
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
)
return
SpeciesAEV
(
species
,
aev
)
torchani/ase.py
View file @
004f5a52
...
@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
...
@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y
=
self
.
strain
(
cell
,
displacement_y
,
1
)
strain_y
=
self
.
strain
(
cell
,
displacement_y
,
1
)
strain_z
=
self
.
strain
(
cell
,
displacement_z
,
2
)
strain_z
=
self
.
strain
(
cell
,
displacement_z
,
2
)
cell
=
cell
+
strain_x
+
strain_y
+
strain_z
cell
=
cell
+
strain_x
+
strain_y
+
strain_z
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
),
cell
=
cell
,
pbc
=
pbc
)
aev
=
self
.
aev_computer
((
species
,
coordinates
),
cell
=
cell
,
pbc
=
pbc
)
.
aevs
else
:
else
:
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
))
aev
=
self
.
aev_computer
((
species
,
coordinates
))
.
aevs
_
,
energy
=
self
.
nn
((
species
,
aev
))
energy
=
self
.
nn
((
species
,
aev
))
.
energies
energy
*=
ase
.
units
.
Hartree
energy
*=
ase
.
units
.
Hartree
self
.
results
[
'energy'
]
=
energy
.
item
()
self
.
results
[
'energy'
]
=
energy
.
item
()
self
.
results
[
'free_energy'
]
=
energy
.
item
()
self
.
results
[
'free_energy'
]
=
energy
.
item
()
...
...
torchani/nn.py
View file @
004f5a52
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
from
typing
import
Tuple
from
typing
import
Tuple
,
NamedTuple
class
SpeciesEnergies
(
NamedTuple
):
species
:
Tensor
energies
:
Tensor
class
ANIModel
(
torch
.
nn
.
Module
):
class
ANIModel
(
torch
.
nn
.
Module
):
...
@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module):
...
@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module):
def
__getitem__
(
self
,
i
):
def
__getitem__
(
self
,
i
):
return
self
.
module_list
[
i
]
return
self
.
module_list
[
i
]
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
species
,
aev
=
species_aev
species
,
aev
=
species_aev
species_
=
species
.
flatten
()
species_
=
species
.
flatten
()
aev
=
aev
.
flatten
(
0
,
1
)
aev
=
aev
.
flatten
(
0
,
1
)
...
@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module):
...
@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module):
input_
=
aev
.
index_select
(
0
,
midx
)
input_
=
aev
.
index_select
(
0
,
midx
)
output
.
masked_scatter_
(
mask
,
m
(
input_
).
flatten
())
output
.
masked_scatter_
(
mask
,
m
(
input_
).
flatten
())
output
=
output
.
view_as
(
species
)
output
=
output
.
view_as
(
species
)
return
species
,
torch
.
sum
(
output
,
dim
=
1
)
return
SpeciesEnergies
(
species
,
torch
.
sum
(
output
,
dim
=
1
)
)
class
Ensemble
(
torch
.
nn
.
Module
):
class
Ensemble
(
torch
.
nn
.
Module
):
...
@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module):
...
@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module):
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
self
.
size
=
len
(
self
.
modules_list
)
self
.
size
=
len
(
self
.
modules_list
)
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
sum_
=
0
sum_
=
0
for
x
in
self
.
modules_list
:
for
x
in
self
.
modules_list
:
sum_
+=
x
(
species_input
)[
1
]
sum_
+=
x
(
species_input
)[
1
]
species
,
_
=
species_input
species
,
_
=
species_input
return
species
,
sum_
/
self
.
size
return
SpeciesEnergies
(
species
,
sum_
/
self
.
size
)
def
__getitem__
(
self
,
i
):
def
__getitem__
(
self
,
i
):
return
self
.
modules_list
[
i
]
return
self
.
modules_list
[
i
]
...
@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module):
...
@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module):
super
(
Sequential
,
self
).
__init__
()
super
(
Sequential
,
self
).
__init__
()
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
]):
for
module
in
self
.
modules_list
:
for
module
in
self
.
modules_list
:
input_
=
module
(
input_
)
input_
=
module
(
input_
)
return
input_
return
input_
...
...
torchani/utils.py
View file @
004f5a52
...
@@ -4,7 +4,8 @@ import torch.utils.data
...
@@ -4,7 +4,8 @@ import torch.utils.data
import
math
import
math
import
numpy
as
np
import
numpy
as
np
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Tuple
from
typing
import
Tuple
,
NamedTuple
from
.nn
import
SpeciesEnergies
def
pad
(
species
):
def
pad
(
species
):
...
@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module):
...
@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module):
properties
[
'energies'
]
=
energies
properties
[
'energies'
]
=
energies
return
atomic_properties
,
properties
return
atomic_properties
,
properties
def
forward
(
self
,
species_energies
:
Tuple
[
Tensor
,
Tensor
])
->
Tuple
[
Tensor
,
Tensor
]
:
def
forward
(
self
,
species_energies
:
Tuple
[
Tensor
,
Tensor
])
->
SpeciesEnergies
:
"""(species, molecular energies)->(species, molecular energies + sae)
"""(species, molecular energies)->(species, molecular energies + sae)
"""
"""
species
,
energies
=
species_energies
species
,
energies
=
species_energies
sae
=
self
.
sae
(
species
).
to
(
energies
.
device
)
sae
=
self
.
sae
(
species
).
to
(
energies
.
device
)
return
species
,
energies
.
to
(
sae
.
dtype
)
+
sae
return
SpeciesEnergies
(
species
,
energies
.
to
(
sae
.
dtype
)
+
sae
)
class
ChemicalSymbolsToInts
:
class
ChemicalSymbolsToInts
:
...
@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
...
@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
],
dim
=
1
)
],
dim
=
1
)
class
FreqsModes
(
NamedTuple
):
freqs
:
Tensor
modes
:
Tensor
def
vibrational_analysis
(
masses
,
hessian
,
unit
=
'cm^-1'
):
def
vibrational_analysis
(
masses
,
hessian
,
unit
=
'cm^-1'
):
"""Computing the vibrational wavenumbers from hessian."""
"""Computing the vibrational wavenumbers from hessian."""
if
unit
!=
'cm^-1'
:
if
unit
!=
'cm^-1'
:
...
@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
...
@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers
=
frequencies
*
17092
wavenumbers
=
frequencies
*
17092
modes
=
(
eigenvectors
.
t
()
*
inv_sqrt_mass
).
reshape
(
frequencies
.
numel
(),
-
1
,
3
)
modes
=
(
eigenvectors
.
t
()
*
inv_sqrt_mass
).
reshape
(
frequencies
.
numel
(),
-
1
,
3
)
return
wavenumbers
,
modes
return
FreqsModes
(
wavenumbers
,
modes
)
__all__
=
[
'pad'
,
'pad_atomic_properties'
,
'present_species'
,
'hessian'
,
__all__
=
[
'pad'
,
'pad_atomic_properties'
,
'present_species'
,
'hessian'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment