Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
55e6d4f0
Unverified
Commit
55e6d4f0
authored
Nov 12, 2020
by
Gao, Xiang
Committed by
GitHub
Nov 12, 2020
Browse files
Revert "Use PyTorch autograd's hessian (#532)" (#534)
This reverts commit
bd9d888a
.
parent
30f4ec4e
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
71 additions
and
17 deletions
+71
-17
docs/api.rst
docs/api.rst
+1
-0
examples/jit.py
examples/jit.py
+14
-8
examples/vibration_analysis.py
examples/vibration_analysis.py
+10
-4
tests/test_utils.py
tests/test_utils.py
+4
-0
tests/test_vibrational.py
tests/test_vibrational.py
+2
-1
torchani/utils.py
torchani/utils.py
+40
-4
No files found.
docs/api.rst
View file @
55e6d4f0
...
@@ -41,6 +41,7 @@ Utilities
...
@@ -41,6 +41,7 @@ Utilities
.. autofunction:: torchani.utils.map2central
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
:members:
.. autofunction:: torchani.utils.hessian
.. autofunction:: torchani.utils.vibrational_analysis
.. autofunction:: torchani.utils.vibrational_analysis
.. autofunction:: torchani.utils.get_atomic_masses
.. autofunction:: torchani.utils.get_atomic_masses
...
...
examples/jit.py
View file @
55e6d4f0
...
@@ -69,7 +69,7 @@ print('Single network energy, eager mode vs loaded jit:', energies_single.item()
...
@@ -69,7 +69,7 @@ print('Single network energy, eager mode vs loaded jit:', energies_single.item()
#
#
# - uses double as dtype instead of float
# - uses double as dtype instead of float
# - don't care about periodic boundary condition
# - don't care about periodic boundary condition
# - in addition to energies, allow returning optionally forces
# - in addition to energies, allow return
s
ing optionally forces
, and hessians
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
#
#
# you could do the following:
# you could do the following:
...
@@ -81,28 +81,34 @@ class CustomModule(torch.nn.Module):
...
@@ -81,28 +81,34 @@ class CustomModule(torch.nn.Module):
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()
def
forward
(
self
,
species
:
Tensor
,
coordinates
:
Tensor
,
return_forces
:
bool
=
False
)
->
Tuple
[
Tensor
,
Optional
[
Tensor
]]:
def
forward
(
self
,
species
:
Tensor
,
coordinates
:
Tensor
,
return_forces
:
bool
=
False
,
if
return_forces
:
return_hessians
:
bool
=
False
)
->
Tuple
[
Tensor
,
Optional
[
Tensor
],
Optional
[
Tensor
]]:
if
return_forces
or
return_hessians
:
coordinates
.
requires_grad_
(
True
)
coordinates
.
requires_grad_
(
True
)
energies
=
self
.
model
((
species
,
coordinates
)).
energies
energies
=
self
.
model
((
species
,
coordinates
)).
energies
forces
:
Optional
[
Tensor
]
=
None
# noqa: E701
forces
:
Optional
[
Tensor
]
=
None
# noqa: E701
if
return_forces
:
hessians
:
Optional
[
Tensor
]
=
None
grad
=
torch
.
autograd
.
grad
([
energies
.
sum
()],
[
coordinates
])[
0
]
if
return_forces
or
return_hessians
:
grad
=
torch
.
autograd
.
grad
([
energies
.
sum
()],
[
coordinates
],
create_graph
=
return_hessians
)[
0
]
assert
grad
is
not
None
assert
grad
is
not
None
forces
=
-
grad
forces
=
-
grad
return
energies
,
forces
if
return_hessians
:
hessians
=
torchani
.
utils
.
hessian
(
coordinates
,
forces
=
forces
)
return
energies
,
forces
,
hessians
custom_model
=
CustomModule
()
custom_model
=
CustomModule
()
compiled_custom_model
=
torch
.
jit
.
script
(
custom_model
)
compiled_custom_model
=
torch
.
jit
.
script
(
custom_model
)
torch
.
jit
.
save
(
compiled_custom_model
,
'compiled_custom_model.pt'
)
torch
.
jit
.
save
(
compiled_custom_model
,
'compiled_custom_model.pt'
)
loaded_compiled_custom_model
=
torch
.
jit
.
load
(
'compiled_custom_model.pt'
)
loaded_compiled_custom_model
=
torch
.
jit
.
load
(
'compiled_custom_model.pt'
)
energies
,
forces
=
custom_model
(
species
,
coordinates
,
True
)
energies
,
forces
,
hessians
=
custom_model
(
species
,
coordinates
,
True
,
True
)
energies_jit
,
forces_jit
=
loaded_compiled_custom_model
(
species
,
coordinates
,
True
)
energies_jit
,
forces_jit
,
hessians_jit
=
loaded_compiled_custom_model
(
species
,
coordinates
,
True
,
True
)
print
(
'Energy, eager mode vs loaded jit:'
,
energies
.
item
(),
energies_jit
.
item
())
print
(
'Energy, eager mode vs loaded jit:'
,
energies
.
item
(),
energies_jit
.
item
())
print
()
print
()
print
(
'Force, eager mode vs loaded jit:
\n
'
,
forces
.
squeeze
(
0
),
'
\n
'
,
forces_jit
.
squeeze
(
0
))
print
(
'Force, eager mode vs loaded jit:
\n
'
,
forces
.
squeeze
(
0
),
'
\n
'
,
forces_jit
.
squeeze
(
0
))
print
()
print
()
torch
.
set_printoptions
(
sci_mode
=
False
,
linewidth
=
1000
)
print
(
'Hessian, eager mode vs loaded jit:
\n
'
,
hessians
.
squeeze
(
0
),
'
\n
'
,
hessians_jit
.
squeeze
(
0
))
examples/vibration_analysis.py
View file @
55e6d4f0
...
@@ -47,12 +47,18 @@ coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_g
...
@@ -47,12 +47,18 @@ coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_g
masses
=
torchani
.
utils
.
get_atomic_masses
(
species
)
masses
=
torchani
.
utils
.
get_atomic_masses
(
species
)
###############################################################################
###############################################################################
# We can use :func:`torch.autograd.functional.hessian` to compute hessian:
# To do vibration analysis, we first need to generate a graph that computes
hessian
=
torch
.
autograd
.
functional
.
hessian
(
lambda
x
:
model
((
species
,
x
)).
energies
,
coordinates
)
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
energies
=
model
((
species
,
coordinates
)).
energies
###############################################################################
###############################################################################
# The Hessian matrix should have shape `(1, 3, 3, 1, 3, 3)`, where 1 means there
# We can now use the energy graph to compute analytical Hessian matrix:
# is only one molecule to compute, 3 means 3 atoms and 3D space.
hessian
=
torchani
.
utils
.
hessian
(
coordinates
,
energies
=
energies
)
###############################################################################
# The Hessian matrix should have shape `(1, 9, 9)`, where 1 means there is only
# one molecule to compute, 9 means `3 atoms * 3D space = 9 degree of freedom`.
print
(
hessian
.
shape
)
print
(
hessian
.
shape
)
###############################################################################
###############################################################################
...
...
tests/test_utils.py
View file @
55e6d4f0
import
unittest
import
unittest
import
torch
import
torchani
import
torchani
...
@@ -9,6 +10,9 @@ class TestUtils(unittest.TestCase):
...
@@ -9,6 +10,9 @@ class TestUtils(unittest.TestCase):
self
.
assertEqual
(
len
(
str2i
),
6
)
self
.
assertEqual
(
len
(
str2i
),
6
)
self
.
assertListEqual
(
str2i
(
'BACCC'
).
tolist
(),
[
1
,
0
,
2
,
2
,
2
])
self
.
assertListEqual
(
str2i
(
'BACCC'
).
tolist
(),
[
1
,
0
,
2
,
2
,
2
])
def
testHessianJIT
(
self
):
torch
.
jit
.
script
(
torchani
.
utils
.
hessian
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
tests/test_vibrational.py
View file @
55e6d4f0
...
@@ -39,7 +39,8 @@ class TestVibrational(unittest.TestCase):
...
@@ -39,7 +39,8 @@ class TestVibrational(unittest.TestCase):
# compute vibrational by torchani
# compute vibrational by torchani
species
=
model
.
species_to_tensor
(
molecule
.
get_chemical_symbols
()).
unsqueeze
(
0
)
species
=
model
.
species_to_tensor
(
molecule
.
get_chemical_symbols
()).
unsqueeze
(
0
)
coordinates
=
torch
.
from_numpy
(
molecule
.
get_positions
()).
unsqueeze
(
0
).
requires_grad_
(
True
)
coordinates
=
torch
.
from_numpy
(
molecule
.
get_positions
()).
unsqueeze
(
0
).
requires_grad_
(
True
)
hessian
=
torch
.
autograd
.
functional
.
hessian
(
lambda
x
:
model
((
species
,
x
)).
energies
,
coordinates
)
_
,
energies
=
model
((
species
,
coordinates
))
hessian
=
torchani
.
utils
.
hessian
(
coordinates
,
energies
=
energies
)
freq2
,
modes2
,
_
,
_
=
torchani
.
utils
.
vibrational_analysis
(
masses
[
species
],
hessian
)
freq2
,
modes2
,
_
,
_
=
torchani
.
utils
.
vibrational_analysis
(
masses
[
species
],
hessian
)
freq2
=
freq2
[
6
:].
float
()
freq2
=
freq2
[
6
:].
float
()
modes2
=
modes2
[
6
:]
modes2
=
modes2
[
6
:]
...
...
torchani/utils.py
View file @
55e6d4f0
...
@@ -240,6 +240,43 @@ class ChemicalSymbolsToInts:
...
@@ -240,6 +240,43 @@ class ChemicalSymbolsToInts:
return
len
(
self
.
rev_species
)
return
len
(
self
.
rev_species
)
def
_get_derivatives_not_none
(
x
:
Tensor
,
y
:
Tensor
,
retain_graph
:
Optional
[
bool
]
=
None
,
create_graph
:
bool
=
False
)
->
Tensor
:
ret
=
torch
.
autograd
.
grad
([
y
.
sum
()],
[
x
],
retain_graph
=
retain_graph
,
create_graph
=
create_graph
)[
0
]
assert
ret
is
not
None
return
ret
def
hessian
(
coordinates
:
Tensor
,
energies
:
Optional
[
Tensor
]
=
None
,
forces
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
"""Compute analytical hessian from the energy graph or force graph.
Arguments:
coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`
energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified,
then `forces` must be `None`. This energies must be computed from
`coordinates` in a graph.
forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified,
then `energies` must be `None`. This forces must be computed from
`coordinates` in a graph.
Returns:
:class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of
atoms in each molecule
"""
if
energies
is
None
and
forces
is
None
:
raise
ValueError
(
'Energies or forces must be specified'
)
if
energies
is
not
None
and
forces
is
not
None
:
raise
ValueError
(
'Energies or forces can not be specified at the same time'
)
if
forces
is
None
:
assert
energies
is
not
None
forces
=
-
_get_derivatives_not_none
(
coordinates
,
energies
,
create_graph
=
True
)
flattened_force
=
forces
.
flatten
(
start_dim
=
1
)
force_components
=
flattened_force
.
unbind
(
dim
=
1
)
return
-
torch
.
stack
([
_get_derivatives_not_none
(
coordinates
,
f
,
retain_graph
=
True
).
flatten
(
start_dim
=
1
)
for
f
in
force_components
],
dim
=
1
)
class
FreqsModes
(
NamedTuple
):
class
FreqsModes
(
NamedTuple
):
freqs
:
Tensor
freqs
:
Tensor
modes
:
Tensor
modes
:
Tensor
...
@@ -279,8 +316,6 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'):
...
@@ -279,8 +316,6 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'):
raise
ValueError
(
'Only meV and cm^-1 are supported right now'
)
raise
ValueError
(
'Only meV and cm^-1 are supported right now'
)
assert
hessian
.
shape
[
0
]
==
1
,
'Currently only supporting computing one molecule a time'
assert
hessian
.
shape
[
0
]
==
1
,
'Currently only supporting computing one molecule a time'
degree_of_freedom
=
hessian
.
shape
[
1
]
*
hessian
.
shape
[
2
]
hessian
=
hessian
.
reshape
(
1
,
degree_of_freedom
,
degree_of_freedom
)
# Solving the eigenvalue problem: Hq = w^2 * T q
# Solving the eigenvalue problem: Hq = w^2 * T q
# where H is the Hessian matrix, q is the normal coordinates,
# where H is the Hessian matrix, q is the normal coordinates,
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass
...
@@ -390,5 +425,6 @@ PERIODIC_TABLE = ['Dummy'] + """
...
@@ -390,5 +425,6 @@ PERIODIC_TABLE = ['Dummy'] + """
"""
.
strip
().
split
()
"""
.
strip
().
split
()
__all__
=
[
'pad_atomic_properties'
,
'present_species'
,
'vibrational_analysis'
,
__all__
=
[
'pad_atomic_properties'
,
'present_species'
,
'hessian'
,
'strip_redundant_padding'
,
'ChemicalSymbolsToInts'
,
'get_atomic_masses'
]
'vibrational_analysis'
,
'strip_redundant_padding'
,
'ChemicalSymbolsToInts'
,
'get_atomic_masses'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment