Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
379d2b33
Unverified
Commit
379d2b33
authored
Sep 07, 2019
by
Farhad Ramezanghorbani
Committed by
GitHub
Sep 07, 2019
Browse files
[JIT] Add TorchScript Compatibility for EnergyShifter (#306)
* enable EnergyShifter scripting * fix * fix
parent
f2170e24
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
5 deletions
+15
-5
tests/test_energies.py
tests/test_energies.py
+12
-4
torchani/utils.py
torchani/utils.py
+3
-1
No files found.
tests/test_energies.py
View file @
379d2b33
...
@@ -16,10 +16,10 @@ class TestEnergies(unittest.TestCase):
...
@@ -16,10 +16,10 @@ class TestEnergies(unittest.TestCase):
self
.
tolerance
=
5e-5
self
.
tolerance
=
5e-5
ani1x
=
torchani
.
models
.
ANI1x
()
ani1x
=
torchani
.
models
.
ANI1x
()
self
.
aev_computer
=
ani1x
.
aev_computer
self
.
aev_computer
=
ani1x
.
aev_computer
nnp
=
ani1x
.
neural_networks
[
0
]
self
.
nnp
=
ani1x
.
neural_networks
[
0
]
s
hift_
energy
=
ani1x
.
energy_shifter
s
elf
.
energy
_shifter
=
ani1x
.
energy_shifter
self
.
nn
=
torch
.
nn
.
Sequential
(
nnp
,
s
hift_
energy
)
self
.
nn
=
torch
.
nn
.
Sequential
(
self
.
nnp
,
s
elf
.
energy
_shifter
)
self
.
model
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
nnp
,
s
hift_
energy
)
self
.
model
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
self
.
nnp
,
s
elf
.
energy
_shifter
)
def
random_skip
(
self
):
def
random_skip
(
self
):
return
False
return
False
...
@@ -116,5 +116,13 @@ class TestEnergies(unittest.TestCase):
...
@@ -116,5 +116,13 @@ class TestEnergies(unittest.TestCase):
self
.
assertLess
(
max_diff
/
math
.
sqrt
(
natoms
),
self
.
tolerance
)
self
.
assertLess
(
max_diff
/
math
.
sqrt
(
natoms
),
self
.
tolerance
)
class
TestEnergiesEnergyShifterJIT
(
TestEnergies
):
def
setUp
(
self
):
super
().
setUp
()
self
.
energy_shifter
=
torch
.
jit
.
script
(
self
.
energy_shifter
)
self
.
nn
=
torch
.
nn
.
Sequential
(
self
.
nnp
,
self
.
energy_shifter
)
self
.
model
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
self
.
nnp
,
self
.
energy_shifter
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchani/utils.py
View file @
379d2b33
...
@@ -3,6 +3,7 @@ import torch.utils.data
...
@@ -3,6 +3,7 @@ import torch.utils.data
import
math
import
math
import
numpy
as
np
import
numpy
as
np
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Tuple
def
pad
(
species
):
def
pad
(
species
):
...
@@ -191,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
...
@@ -191,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
intercept
=
self
.
self_energies
[
-
1
]
intercept
=
self
.
self_energies
[
-
1
]
self_energies
=
self
.
self_energies
[
species
]
self_energies
=
self
.
self_energies
[
species
]
self_energies
[
species
==
-
1
]
=
0
self_energies
[
species
==
torch
.
tensor
(
-
1
)
]
=
torch
.
tensor
(
0
)
return
self_energies
.
sum
(
dim
=
1
)
+
intercept
return
self_energies
.
sum
(
dim
=
1
)
+
intercept
def
subtract_from_dataset
(
self
,
atomic_properties
,
properties
):
def
subtract_from_dataset
(
self
,
atomic_properties
,
properties
):
...
@@ -210,6 +211,7 @@ class EnergyShifter(torch.nn.Module):
...
@@ -210,6 +211,7 @@ class EnergyShifter(torch.nn.Module):
return
atomic_properties
,
properties
return
atomic_properties
,
properties
def
forward
(
self
,
species_energies
):
def
forward
(
self
,
species_energies
):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""(species, molecular energies)->(species, molecular energies + sae)
"""(species, molecular energies)->(species, molecular energies + sae)
"""
"""
species
,
energies
=
species_energies
species
,
energies
=
species_energies
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment