Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
629bc698
Unverified
Commit
629bc698
authored
Jul 31, 2018
by
Gao, Xiang
Committed by
GitHub
Jul 31, 2018
Browse files
make aev computer take a single tuple as input (#37)
parent
d07bd02c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
148 additions
and
183 deletions
+148
-183
tests/test_aev.py
tests/test_aev.py
+3
-1
tests/test_benchmark.py
tests/test_benchmark.py
+7
-4
torchani/aev.py
torchani/aev.py
+136
-23
torchani/aev_base.py
torchani/aev_base.py
+0
-131
torchani/models/ani_model.py
torchani/models/ani_model.py
+2
-24
No files found.
tests/test_aev.py
View file @
629bc698
...
@@ -16,7 +16,9 @@ class TestAEV(unittest.TestCase):
...
@@ -16,7 +16,9 @@ class TestAEV(unittest.TestCase):
def
_test_molecule
(
self
,
coordinates
,
species
,
expected_radial
,
def
_test_molecule
(
self
,
coordinates
,
species
,
expected_radial
,
expected_angular
):
expected_angular
):
radial
,
angular
=
self
.
aev
(
coordinates
,
species
)
aev
=
self
.
aev
((
coordinates
,
species
))
radial
=
aev
[...,
:
self
.
aev
.
radial_length
]
angular
=
aev
[...,
self
.
aev
.
radial_length
:]
radial_diff
=
expected_radial
-
radial
radial_diff
=
expected_radial
-
radial
radial_max_error
=
torch
.
max
(
torch
.
abs
(
radial_diff
)).
item
()
radial_max_error
=
torch
.
max
(
torch
.
abs
(
radial_diff
)).
item
()
angular_diff
=
expected_angular
-
angular
angular_diff
=
expected_angular
-
angular
...
...
tests/test_benchmark.py
View file @
629bc698
...
@@ -41,7 +41,10 @@ class TestBenchmark(unittest.TestCase):
...
@@ -41,7 +41,10 @@ class TestBenchmark(unittest.TestCase):
self
.
assertEqual
(
module
.
timers
[
i
],
0
)
self
.
assertEqual
(
module
.
timers
[
i
],
0
)
old_timers
=
copy
.
copy
(
module
.
timers
)
old_timers
=
copy
.
copy
(
module
.
timers
)
for
_
in
range
(
self
.
count
):
for
_
in
range
(
self
.
count
):
module
(
self
.
coordinates
,
self
.
species
)
if
isinstance
(
module
,
torchani
.
aev
.
AEVComputer
):
module
((
self
.
coordinates
,
self
.
species
))
else
:
module
(
self
.
coordinates
,
self
.
species
)
for
i
in
keys
:
for
i
in
keys
:
self
.
assertLess
(
old_timers
[
i
],
module
.
timers
[
i
])
self
.
assertLess
(
old_timers
[
i
],
module
.
timers
[
i
])
for
i
in
asserts
:
for
i
in
asserts
:
...
@@ -90,16 +93,16 @@ class TestBenchmark(unittest.TestCase):
...
@@ -90,16 +93,16 @@ class TestBenchmark(unittest.TestCase):
'total>mask_r'
,
'total>mask_a'
'total>mask_r'
,
'total>mask_a'
])
])
def
testModel
OnAEV
(
self
):
def
test
ANI
Model
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
dtype
,
device
=
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
,
benchmark
=
True
)
aev_computer
,
benchmark
=
True
)
self
.
_testModule
(
model
,
[
'forward>aev'
,
'forward>nn'
])
self
.
_testModule
(
model
,
[
'forward>nn'
])
model
=
torchani
.
models
.
NeuroChemNNP
(
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
,
benchmark
=
True
,
derivative
=
True
)
aev_computer
,
benchmark
=
True
,
derivative
=
True
)
self
.
_testModule
(
self
.
_testModule
(
model
,
[
'forward>aev'
,
'forward>nn'
,
'forward>derivative'
])
model
,
[
'forward>nn'
,
'forward>derivative'
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
torchani/aev.py
View file @
629bc698
import
torch
import
torch
import
itertools
import
itertools
import
numpy
import
numpy
from
.aev_base
import
AEVComputer
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.benchmarked
import
BenchmarkedModule
class
AEVComputer
(
BenchmarkedModule
):
__constants__
=
[
'Rcr'
,
'Rca'
,
'dtype'
,
'device'
,
'radial_sublength'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
"""Base class of various implementations of AEV computer
Attributes
----------
benchmark : boolean
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
Cutoff radius
EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
radial_sublength : int
The length of radial subaev of a single species
radial_length : int
The length of full radial aev
angular_sublength : int
The length of angular subaev of a single species
angular_length : int
The length of full angular aev
aev_length : int
The length of full aev
"""
def
__init__
(
self
,
benchmark
=
False
,
dtype
=
default_dtype
,
device
=
default_device
,
const_file
=
buildin_const_file
):
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
self
.
dtype
=
dtype
self
.
const_file
=
const_file
self
.
device
=
device
# load constants from const file
with
open
(
const_file
)
as
f
:
for
i
in
f
:
try
:
line
=
[
x
.
strip
()
for
x
in
i
.
split
(
'='
)]
name
=
line
[
0
]
value
=
line
[
1
]
if
name
==
'Rcr'
or
name
==
'Rca'
:
setattr
(
self
,
name
,
float
(
value
))
elif
name
in
[
'EtaR'
,
'ShfR'
,
'Zeta'
,
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
value
=
torch
.
tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
setattr
(
self
,
name
,
value
)
elif
name
==
'Atyp'
:
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
self
.
species
=
value
except
Exception
:
raise
ValueError
(
'unable to parse const file'
)
# Compute lengths
self
.
radial_sublength
=
self
.
EtaR
.
shape
[
0
]
*
self
.
ShfR
.
shape
[
0
]
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
angular_sublength
=
self
.
EtaA
.
shape
[
0
]
*
\
self
.
Zeta
.
shape
[
0
]
*
self
.
ShfA
.
shape
[
0
]
*
self
.
ShfZ
.
shape
[
0
]
species
=
len
(
self
.
species
)
self
.
angular_length
=
int
(
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self
.
EtaR
=
self
.
EtaR
.
view
(
-
1
,
1
)
self
.
ShfR
=
self
.
ShfR
.
view
(
1
,
-
1
)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self
.
EtaA
=
self
.
EtaA
.
view
(
-
1
,
1
,
1
,
1
)
self
.
Zeta
=
self
.
Zeta
.
view
(
1
,
-
1
,
1
,
1
)
self
.
ShfA
=
self
.
ShfA
.
view
(
1
,
1
,
-
1
,
1
)
self
.
ShfZ
=
self
.
ShfZ
.
view
(
1
,
1
,
1
,
-
1
)
def
sort_by_species
(
self
,
data
,
species
):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
data : torch.Tensor
Tensor of shape (conformations, atoms, ...) for data.
species : list
List storing species of each atom.
Returns
-------
(torch.Tensor, list)
Tuple of (sorted data, sorted species).
"""
atoms
=
list
(
zip
(
species
,
torch
.
unbind
(
data
,
1
)))
atoms
=
sorted
(
atoms
,
key
=
lambda
x
:
self
.
species
.
index
(
x
[
0
]))
species
=
[
s
for
s
,
_
in
atoms
]
data
=
torch
.
stack
([
c
for
_
,
c
in
atoms
],
dim
=
1
)
return
data
,
species
def
forward
(
self
,
coordinates_species
):
"""Compute AEV from coordinates and species
Parameters
----------
(coordinates, species)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise
NotImplementedError
(
'subclass must override this method'
)
def
_cutoff_cosine
(
distances
,
cutoff
):
def
_cutoff_cosine
(
distances
,
cutoff
):
...
@@ -353,7 +482,8 @@ class SortedAEV(AEVComputer):
...
@@ -353,7 +482,8 @@ class SortedAEV(AEVComputer):
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
def
forward
(
self
,
coordinates
,
species
):
def
forward
(
self
,
coordinates_species
):
coordinates
,
species
=
coordinates_species
species
=
self
.
species_to_tensor
(
species
)
species
=
self
.
species_to_tensor
(
species
)
present_species
=
species
.
unique
(
sorted
=
True
)
present_species
=
species
.
unique
(
sorted
=
True
)
...
@@ -365,24 +495,7 @@ class SortedAEV(AEVComputer):
...
@@ -365,24 +495,7 @@ class SortedAEV(AEVComputer):
species_a
=
species
[
indices_a
]
species_a
=
species
[
indices_a
]
mask_a
=
self
.
compute_mask_a
(
species_a
,
present_species
)
mask_a
=
self
.
compute_mask_a
(
species_a
,
present_species
)
return
self
.
assemble
(
radial_terms
,
angular_terms
,
present_species
,
radial
,
angular
=
self
.
assemble
(
radial_terms
,
angular_terms
,
mask_r
,
mask_a
)
present_species
,
mask_r
,
mask_a
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
def
export_radial_subaev_onnx
(
self
,
filename
):
return
fullaev
"""Export the operation that compute radial subaev into onnx format
Parameters
----------
filename : string
Name of the file to store exported networks.
"""
class
M
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
outerself
):
super
(
M
,
self
).
__init__
()
self
.
outerself
=
outerself
def
forward
(
self
,
center
,
neighbors
):
return
self
.
outerself
.
radial_subaev
(
center
,
neighbors
)
dummy_center
=
torch
.
randn
(
1
,
3
)
dummy_neighbors
=
torch
.
randn
(
1
,
5
,
3
)
torch
.
onnx
.
export
(
M
(
self
),
(
dummy_center
,
dummy_neighbors
),
filename
)
torchani/aev_base.py
deleted
100644 → 0
View file @
d07bd02c
import
torch
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.benchmarked
import
BenchmarkedModule
class
AEVComputer
(
BenchmarkedModule
):
__constants__
=
[
'Rcr'
,
'Rca'
,
'dtype'
,
'device'
,
'radial_sublength'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
"""Base class of various implementations of AEV computer
Attributes
----------
benchmark : boolean
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
Cutoff radius
EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
radial_sublength : int
The length of radial subaev of a single species
radial_length : int
The length of full radial aev
angular_sublength : int
The length of angular subaev of a single species
angular_length : int
The length of full angular aev
aev_length : int
The length of full aev
"""
def
__init__
(
self
,
benchmark
=
False
,
dtype
=
default_dtype
,
device
=
default_device
,
const_file
=
buildin_const_file
):
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
self
.
dtype
=
dtype
self
.
const_file
=
const_file
self
.
device
=
device
# load constants from const file
with
open
(
const_file
)
as
f
:
for
i
in
f
:
try
:
line
=
[
x
.
strip
()
for
x
in
i
.
split
(
'='
)]
name
=
line
[
0
]
value
=
line
[
1
]
if
name
==
'Rcr'
or
name
==
'Rca'
:
setattr
(
self
,
name
,
float
(
value
))
elif
name
in
[
'EtaR'
,
'ShfR'
,
'Zeta'
,
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
value
=
torch
.
tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
setattr
(
self
,
name
,
value
)
elif
name
==
'Atyp'
:
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
self
.
species
=
value
except
Exception
:
raise
ValueError
(
'unable to parse const file'
)
# Compute lengths
self
.
radial_sublength
=
self
.
EtaR
.
shape
[
0
]
*
self
.
ShfR
.
shape
[
0
]
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
angular_sublength
=
self
.
EtaA
.
shape
[
0
]
*
\
self
.
Zeta
.
shape
[
0
]
*
self
.
ShfA
.
shape
[
0
]
*
self
.
ShfZ
.
shape
[
0
]
species
=
len
(
self
.
species
)
self
.
angular_length
=
int
(
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self
.
EtaR
=
self
.
EtaR
.
view
(
-
1
,
1
)
self
.
ShfR
=
self
.
ShfR
.
view
(
1
,
-
1
)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self
.
EtaA
=
self
.
EtaA
.
view
(
-
1
,
1
,
1
,
1
)
self
.
Zeta
=
self
.
Zeta
.
view
(
1
,
-
1
,
1
,
1
)
self
.
ShfA
=
self
.
ShfA
.
view
(
1
,
1
,
-
1
,
1
)
self
.
ShfZ
=
self
.
ShfZ
.
view
(
1
,
1
,
1
,
-
1
)
def
sort_by_species
(
self
,
data
,
species
):
"""Sort the data by its species according to the order in `self.species`
Parameters
----------
data : torch.Tensor
Tensor of shape (conformations, atoms, ...) for data.
species : list
List storing species of each atom.
Returns
-------
(torch.Tensor, list)
Tuple of (sorted data, sorted species).
"""
atoms
=
list
(
zip
(
species
,
torch
.
unbind
(
data
,
1
)))
atoms
=
sorted
(
atoms
,
key
=
lambda
x
:
self
.
species
.
index
(
x
[
0
]))
species
=
[
s
for
s
,
_
in
atoms
]
data
=
torch
.
stack
([
c
for
_
,
c
in
atoms
],
dim
=
1
)
return
data
,
species
def
forward
(
self
,
coordinates
,
species
):
"""Compute AEV from coordinates and species
Parameters
----------
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise
NotImplementedError
(
'subclass must override this method'
)
torchani/models/ani_model.py
View file @
629bc698
from
..aev
_base
import
AEVComputer
from
..aev
import
AEVComputer
import
torch
import
torch
from
..benchmarked
import
BenchmarkedModule
from
..benchmarked
import
BenchmarkedModule
...
@@ -67,7 +67,6 @@ class ANIModel(BenchmarkedModule):
...
@@ -67,7 +67,6 @@ class ANIModel(BenchmarkedModule):
'derivative can only be computed for output length 1'
)
'derivative can only be computed for output length 1'
)
if
benchmark
:
if
benchmark
:
self
.
compute_aev
=
self
.
_enable_benchmark
(
self
.
compute_aev
,
'aev'
)
self
.
aev_to_output
=
self
.
_enable_benchmark
(
self
.
aev_to_output
=
self
.
_enable_benchmark
(
self
.
aev_to_output
,
'nn'
)
self
.
aev_to_output
,
'nn'
)
if
derivative
:
if
derivative
:
...
@@ -75,27 +74,6 @@ class ANIModel(BenchmarkedModule):
...
@@ -75,27 +74,6 @@ class ANIModel(BenchmarkedModule):
self
.
compute_derivative
,
'derivative'
)
self
.
compute_derivative
,
'derivative'
)
self
.
forward
=
self
.
_enable_benchmark
(
self
.
forward
,
'forward'
)
self
.
forward
=
self
.
_enable_benchmark
(
self
.
forward
,
'forward'
)
def
compute_aev
(
self
,
coordinates
,
species
):
"""Compute full AEV
Parameters
----------
coordinates : torch.Tensor
The pytorch tensor of shape (conformations, atoms, 3) storing
the coordinates of all atoms of all conformations.
species : list of string
List of string storing the species for each atom.
Returns
-------
torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs.
"""
radial_aev
,
angular_aev
=
self
.
aev_computer
(
coordinates
,
species
)
fullaev
=
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=
2
)
return
fullaev
def
aev_to_output
(
self
,
aev
,
species
):
def
aev_to_output
(
self
,
aev
,
species
):
"""Compute output from aev
"""Compute output from aev
...
@@ -173,7 +151,7 @@ class ANIModel(BenchmarkedModule):
...
@@ -173,7 +151,7 @@ class ANIModel(BenchmarkedModule):
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
_coordinates
,
_species
=
self
.
aev_computer
.
sort_by_species
(
_coordinates
,
_species
=
self
.
aev_computer
.
sort_by_species
(
coordinates
,
species
)
coordinates
,
species
)
aev
=
self
.
compute
_aev
(
_coordinates
,
_species
)
aev
=
self
.
aev_
compute
r
(
(
_coordinates
,
_species
)
)
output
=
self
.
aev_to_output
(
aev
,
_species
)
output
=
self
.
aev_to_output
(
aev
,
_species
)
if
not
self
.
derivative
:
if
not
self
.
derivative
:
return
output
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment