Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
35531421
Unverified
Commit
35531421
authored
Apr 06, 2019
by
Gao, Xiang
Committed by
GitHub
Apr 06, 2019
Browse files
Completely rewrite AEVComputer (#197)
parent
bc5f4312
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
469 additions
and
516 deletions
+469
-516
docs/api.rst
docs/api.rst
+1
-2
tests/test_aev.py
tests/test_aev.py
+148
-11
tests/test_ase.py
tests/test_ase.py
+3
-195
tests/test_data.py
tests/test_data.py
+1
-1
tests/test_energies.py
tests/test_energies.py
+0
-1
tests/test_forces.py
tests/test_forces.py
+0
-1
tests/test_ignite.py
tests/test_ignite.py
+2
-1
torchani/aev.py
torchani/aev.py
+267
-207
torchani/ase.py
torchani/ase.py
+9
-93
torchani/utils.py
torchani/utils.py
+38
-4
No files found.
docs/api.rst
View file @
35531421
...
...
@@ -37,6 +37,7 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
...
...
@@ -61,8 +62,6 @@ ASE Interface
=============
.. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator
Ignite Helpers
...
...
tests/test_aev.py
View file @
35531421
...
...
@@ -3,7 +3,9 @@ import torchani
import
unittest
import
os
import
pickle
import
random
import
itertools
import
ase
import
math
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
N
=
97
...
...
@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase):
self
.
_assertAEVEqual
(
radial
,
angular
,
aev
)
class
Test
AEVASENeighborLi
st
(
Test
AEV
):
class
Test
PBCSeeEachOther
(
unitte
st
.
Test
Case
):
def
setUp
(
self
):
super
(
TestAEVASENeighborList
,
self
).
setUp
()
self
.
aev_computer
.
neighborlist
=
torchani
.
ase
.
NeighborList
()
self
.
builtin
=
torchani
.
neurochem
.
Builtins
()
self
.
aev_computer
=
self
.
builtin
.
aev_computer
.
to
(
torch
.
double
)
def
testTranslationalInvariancePBC
(
self
):
coordinates
=
torch
.
tensor
(
[[[
0
,
0
,
0
],
[
1
,
0
,
0
],
[
0
,
1
,
0
],
[
0
,
0
,
1
],
[
0
,
1
,
1
]]],
dtype
=
torch
.
double
,
requires_grad
=
True
)
cell
=
torch
.
eye
(
3
,
dtype
=
torch
.
double
)
*
2
species
=
torch
.
tensor
([[
1
,
0
,
0
,
0
,
0
]],
dtype
=
torch
.
long
)
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
,
cell
,
pbc
))
for
_
in
range
(
100
):
translation
=
torch
.
randn
(
3
,
dtype
=
torch
.
double
)
_
,
aev2
=
self
.
aev_computer
((
species
,
coordinates
+
translation
,
cell
,
pbc
))
self
.
assertTrue
(
torch
.
allclose
(
aev
,
aev2
))
def
testPBCConnersSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
cell
=
torch
.
eye
(
3
,
dtype
=
torch
.
double
)
*
10
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
allshifts
=
torchani
.
aev
.
compute_shifts
(
cell
,
pbc
,
1
)
xyz1
=
torch
.
tensor
([
0.1
,
0.1
,
0.1
])
xyz2s
=
[
torch
.
tensor
([
9.9
,
0.0
,
0.0
]),
torch
.
tensor
([
0.0
,
9.9
,
0.0
]),
torch
.
tensor
([
0.0
,
0.0
,
9.9
]),
torch
.
tensor
([
9.9
,
9.9
,
0.0
]),
torch
.
tensor
([
0.0
,
9.9
,
9.9
]),
torch
.
tensor
([
9.9
,
0.0
,
9.9
]),
torch
.
tensor
([
9.9
,
9.9
,
9.9
]),
]
for
xyz2
in
xyz2s
:
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
to
(
torch
.
double
).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
def
testPBCSurfaceSeeEachOther
(
self
):
cell
=
torch
.
eye
(
3
,
dtype
=
torch
.
double
)
*
10
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
allshifts
=
torchani
.
aev
.
compute_shifts
(
cell
,
pbc
,
1
)
species
=
torch
.
tensor
([[
0
,
0
]])
for
i
in
range
(
3
):
xyz1
=
torch
.
tensor
([
5.0
,
5.0
,
5.0
],
dtype
=
torch
.
double
)
xyz1
[
i
]
=
0.1
xyz2
=
xyz1
.
clone
()
xyz2
[
i
]
=
9.9
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
def
testPBCEdgesSeeEachOther
(
self
):
cell
=
torch
.
eye
(
3
,
dtype
=
torch
.
double
)
*
10
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
allshifts
=
torchani
.
aev
.
compute_shifts
(
cell
,
pbc
,
1
)
species
=
torch
.
tensor
([[
0
,
0
]])
for
i
,
j
in
itertools
.
combinations
(
range
(
3
),
2
):
xyz1
=
torch
.
tensor
([
5.0
,
5.0
,
5.0
],
dtype
=
torch
.
double
)
xyz1
[
i
]
=
0.1
xyz1
[
j
]
=
0.1
for
new_i
,
new_j
in
[[
0.1
,
9.9
],
[
9.9
,
0.1
],
[
9.9
,
9.9
]]:
xyz2
=
xyz1
.
clone
()
xyz2
[
i
]
=
new_i
xyz2
[
j
]
=
new_i
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
def
testNonRectangularPBCConnersSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
cell
=
ase
.
geometry
.
cellpar_to_cell
([
10
,
10
,
10
*
math
.
sqrt
(
2
),
90
,
45
,
90
])
cell
=
torch
.
tensor
(
ase
.
geometry
.
complete_cell
(
cell
),
dtype
=
torch
.
double
)
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
allshifts
=
torchani
.
aev
.
compute_shifts
(
cell
,
pbc
,
1
)
xyz1
=
torch
.
tensor
([
0.1
,
0.1
,
0.05
],
dtype
=
torch
.
double
)
xyz2
=
torch
.
tensor
([
10.0
,
0.1
,
0.1
],
dtype
=
torch
.
double
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
class
TestAEVOnBoundary
(
unittest
.
TestCase
):
def
transform
(
self
,
x
):
"""To reduce the size of test cases for faster test speed"""
return
x
[:
2
,
...]
def
random_skip
(
self
):
"""To reduce the size of test cases for faster test speed"""
return
random
.
random
()
<
0.95
def
setUp
(
self
):
self
.
eps
=
1e-9
cell
=
ase
.
geometry
.
cellpar_to_cell
([
100
,
100
,
100
*
math
.
sqrt
(
2
),
90
,
45
,
90
])
self
.
cell
=
torch
.
tensor
(
ase
.
geometry
.
complete_cell
(
cell
),
dtype
=
torch
.
double
)
self
.
inv_cell
=
torch
.
inverse
(
self
.
cell
)
self
.
coordinates
=
torch
.
tensor
([[[
0.0
,
0.0
,
0.0
],
[
1.0
,
-
0.1
,
-
0.1
],
[
-
0.1
,
1.0
,
-
0.1
],
[
-
0.1
,
-
0.1
,
1.0
],
[
-
1.0
,
-
1.0
,
-
1.0
]]],
dtype
=
torch
.
double
)
self
.
species
=
torch
.
tensor
([[
1
,
0
,
0
,
0
]])
self
.
pbc
=
torch
.
ones
(
3
,
dtype
=
torch
.
uint8
)
self
.
v1
,
self
.
v2
,
self
.
v3
=
self
.
cell
self
.
center_coordinates
=
self
.
coordinates
+
0.5
*
(
self
.
v1
+
self
.
v2
+
self
.
v3
)
builtin
=
torchani
.
neurochem
.
Builtins
()
self
.
aev_computer
=
builtin
.
aev_computer
.
to
(
torch
.
double
)
_
,
self
.
aev
=
self
.
aev_computer
((
self
.
species
,
self
.
center_coordinates
,
self
.
cell
,
self
.
pbc
))
def
assertInCell
(
self
,
coordinates
):
coordinates_cell
=
coordinates
@
self
.
inv_cell
self
.
assertTrue
(
torch
.
allclose
(
coordinates
,
coordinates_cell
@
self
.
cell
))
in_cell
=
(
coordinates_cell
>=
-
self
.
eps
)
&
(
coordinates_cell
<=
1
+
self
.
eps
)
self
.
assertTrue
(
in_cell
.
all
())
def
assertNotInCell
(
self
,
coordinates
):
coordinates_cell
=
coordinates
@
self
.
inv_cell
self
.
assertTrue
(
torch
.
allclose
(
coordinates
,
coordinates_cell
@
self
.
cell
))
in_cell
=
(
coordinates_cell
>=
-
self
.
eps
)
&
(
coordinates_cell
<=
1
+
self
.
eps
)
self
.
assertFalse
(
in_cell
.
all
())
def
testCornerSurfaceAndEdge
(
self
):
for
i
,
j
,
k
in
itertools
.
product
([
0
,
0.5
,
1
],
repeat
=
3
):
if
i
==
0.5
and
j
==
0.5
and
k
==
0.5
:
continue
coordinates
=
self
.
coordinates
+
i
*
self
.
v1
+
j
*
self
.
v2
+
k
*
self
.
v3
self
.
assertNotInCell
(
coordinates
)
coordinates
=
torchani
.
utils
.
map2central
(
self
.
cell
,
coordinates
,
self
.
pbc
)
self
.
assertInCell
(
coordinates
)
_
,
aev
=
self
.
aev_computer
((
self
.
species
,
coordinates
,
self
.
cell
,
self
.
pbc
))
self
.
assertGreater
(
aev
.
abs
().
max
().
item
(),
0
)
self
.
assertTrue
(
torch
.
allclose
(
aev
,
self
.
aev
))
if
__name__
==
'__main__'
:
...
...
tests/test_ase.py
View file @
35531421
from
ase.lattice.cubic
import
Diamond
from
ase.md.langevin
import
Langevin
from
ase
import
units
,
Atoms
from
ase
import
units
from
ase.calculators.test
import
numeric_force
import
torch
import
torchani
import
unittest
import
numpy
import
itertools
import
math
import
os
import
pickle
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
N
=
97
...
...
@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps):
class
TestASE
(
unittest
.
TestCase
):
def
_
test
Force
(
self
,
pbc
):
atoms
=
Diamond
(
symbol
=
"C"
,
pbc
=
pbc
)
def
test
WithNumericalForceWithPBCEnabled
(
self
):
atoms
=
Diamond
(
symbol
=
"C"
,
pbc
=
True
)
builtin
=
torchani
.
neurochem
.
Builtins
()
calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
...
...
@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase):
if
avgf
>
0
:
self
.
assertLess
(
df
/
avgf
,
0.1
)
def
testForceWithPBCEnabled
(
self
):
self
.
_testForce
(
True
)
def
testForceWithPBCDisabled
(
self
):
self
.
_testForce
(
False
)
def
testANIDataset
(
self
):
builtin
=
torchani
.
neurochem
.
Builtins
()
calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
)
default_neighborlist_calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
,
_default_neighborlist
=
True
)
nnp
=
torch
.
nn
.
Sequential
(
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
)
for
i
in
range
(
N
):
datafile
=
os
.
path
.
join
(
path
,
'test_data/ANI1_subset/{}'
.
format
(
i
))
with
open
(
datafile
,
'rb'
)
as
f
:
coordinates
,
species
,
_
,
_
,
_
,
_
=
pickle
.
load
(
f
)
coordinates
=
coordinates
[
0
]
species
=
species
[
0
]
species_str
=
[
builtin
.
consts
.
species
[
i
]
for
i
in
species
]
atoms
=
Atoms
(
species_str
,
positions
=
coordinates
)
atoms
.
set_calculator
(
calculator
)
energy1
=
atoms
.
get_potential_energy
()
/
units
.
Hartree
forces1
=
atoms
.
get_forces
()
/
units
.
Hartree
atoms2
=
Atoms
(
species_str
,
positions
=
coordinates
)
atoms2
.
set_calculator
(
default_neighborlist_calculator
)
energy2
=
atoms2
.
get_potential_energy
()
/
units
.
Hartree
forces2
=
atoms2
.
get_forces
()
/
units
.
Hartree
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
).
unsqueeze
(
0
)
_
,
energy3
=
nnp
((
torch
.
from_numpy
(
species
).
unsqueeze
(
0
),
coordinates
))
forces3
=
-
torch
.
autograd
.
grad
(
energy3
.
squeeze
(),
coordinates
)[
0
].
numpy
()
energy3
=
energy3
.
item
()
self
.
assertLess
(
abs
(
energy1
-
energy2
),
tol
)
self
.
assertLess
(
abs
(
energy1
-
energy3
),
tol
)
diff_f12
=
torch
.
tensor
(
forces1
-
forces2
).
abs
().
max
().
item
()
self
.
assertLess
(
diff_f12
,
tol
)
diff_f13
=
torch
.
tensor
(
forces1
-
forces3
).
abs
().
max
().
item
()
self
.
assertLess
(
diff_f13
,
tol
)
def
testForceAgainstDefaultNeighborList
(
self
):
atoms
=
Diamond
(
symbol
=
"C"
,
pbc
=
False
)
builtin
=
torchani
.
neurochem
.
Builtins
()
calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
)
default_neighborlist_calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
,
_default_neighborlist
=
True
)
atoms
.
set_calculator
(
calculator
)
dyn
=
Langevin
(
atoms
,
5
*
units
.
fs
,
50
*
units
.
kB
,
0.002
)
def
test_energy
(
a
=
atoms
):
a
=
a
.
copy
()
a
.
set_calculator
(
calculator
)
e1
=
a
.
get_potential_energy
()
a
.
set_calculator
(
default_neighborlist_calculator
)
e2
=
a
.
get_potential_energy
()
self
.
assertLess
(
abs
(
e1
-
e2
),
tol
)
dyn
.
attach
(
test_energy
,
interval
=
1
)
dyn
.
run
(
500
)
def
testTranslationalInvariancePBC
(
self
):
atoms
=
Atoms
(
'CH4'
,
[[
0
,
0
,
0
],
[
1
,
0
,
0
],
[
0
,
1
,
0
],
[
0
,
0
,
1
],
[
0
,
1
,
1
]],
cell
=
[
2
,
2
,
2
],
pbc
=
True
)
builtin
=
torchani
.
neurochem
.
Builtins
()
calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
)
atoms
.
set_calculator
(
calculator
)
e
=
atoms
.
get_potential_energy
()
for
_
in
range
(
100
):
positions
=
atoms
.
get_positions
()
translation
=
(
numpy
.
random
.
rand
(
3
)
-
0.5
)
*
2
atoms
.
set_positions
(
positions
+
translation
)
self
.
assertEqual
(
e
,
atoms
.
get_potential_energy
())
def
assertTensorEqual
(
self
,
a
,
b
):
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1e-6
)
def
testPBCConnersSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
neighborlist
=
torchani
.
ase
.
NeighborList
(
cell
=
[
10
,
10
,
10
],
pbc
=
True
)
xyz1
=
torch
.
tensor
([
0.1
,
0.1
,
0.1
])
xyz2s
=
[
torch
.
tensor
([
9.9
,
0.0
,
0.0
]),
torch
.
tensor
([
0.0
,
9.9
,
0.0
]),
torch
.
tensor
([
0.0
,
0.0
,
9.9
]),
torch
.
tensor
([
9.9
,
9.9
,
0.0
]),
torch
.
tensor
([
0.0
,
9.9
,
9.9
]),
torch
.
tensor
([
9.9
,
0.0
,
9.9
]),
torch
.
tensor
([
9.9
,
9.9
,
9.9
]),
]
for
xyz2
in
xyz2s
:
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
s
,
_
,
D
=
neighborlist
(
species
,
coordinates
,
1
)
self
.
assertListEqual
(
list
(
s
.
shape
),
[
1
,
2
,
1
])
neighbor_coordinate
=
D
[
0
][
0
].
squeeze
()
+
xyz1
mirror
=
xyz2
for
i
in
range
(
3
):
if
mirror
[
i
]
>
5
:
mirror
[
i
]
-=
10
self
.
assertTensorEqual
(
neighbor_coordinate
,
mirror
)
def
testPBCSurfaceSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
neighborlist
=
torchani
.
ase
.
NeighborList
(
cell
=
[
10
,
10
,
10
],
pbc
=
True
)
for
i
in
range
(
3
):
xyz1
=
torch
.
tensor
([
5.0
,
5.0
,
5.0
])
xyz1
[
i
]
=
0.1
xyz2
=
xyz1
.
clone
()
xyz2
[
i
]
=
9.9
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
s
,
_
,
D
=
neighborlist
(
species
,
coordinates
,
1
)
self
.
assertListEqual
(
list
(
s
.
shape
),
[
1
,
2
,
1
])
neighbor_coordinate
=
D
[
0
][
0
].
squeeze
()
+
xyz1
xyz2
[
i
]
=
-
0.1
self
.
assertTensorEqual
(
neighbor_coordinate
,
xyz2
)
def
testPBCEdgesSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
neighborlist
=
torchani
.
ase
.
NeighborList
(
cell
=
[
10
,
10
,
10
],
pbc
=
True
)
for
i
,
j
in
itertools
.
combinations
(
range
(
3
),
2
):
xyz1
=
torch
.
tensor
([
5.0
,
5.0
,
5.0
])
xyz1
[
i
]
=
0.1
xyz1
[
j
]
=
0.1
for
new_i
,
new_j
in
[[
0.1
,
9.9
],
[
9.9
,
0.1
],
[
9.9
,
9.9
]]:
xyz2
=
xyz1
.
clone
()
xyz2
[
i
]
=
new_i
xyz2
[
j
]
=
new_i
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
s
,
_
,
D
=
neighborlist
(
species
,
coordinates
,
1
)
self
.
assertListEqual
(
list
(
s
.
shape
),
[
1
,
2
,
1
])
neighbor_coordinate
=
D
[
0
][
0
].
squeeze
()
+
xyz1
if
xyz2
[
i
]
>
5
:
xyz2
[
i
]
=
-
0.1
if
xyz2
[
j
]
>
5
:
xyz2
[
j
]
=
-
0.1
self
.
assertTensorEqual
(
neighbor_coordinate
,
xyz2
)
def
testNonRectangularPBCConnersSeeEachOther
(
self
):
species
=
torch
.
tensor
([[
0
,
0
]])
neighborlist
=
torchani
.
ase
.
NeighborList
(
cell
=
[
10
,
10
,
10
*
math
.
sqrt
(
2
),
90
,
45
,
90
],
pbc
=
True
)
xyz1
=
torch
.
tensor
([
0.1
,
0.1
,
0.05
])
xyz2
=
torch
.
tensor
([
10.0
,
0.1
,
0.1
])
mirror
=
torch
.
tensor
([
0.0
,
0.1
,
0.1
])
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
s
,
_
,
D
=
neighborlist
(
species
,
coordinates
,
1
)
self
.
assertListEqual
(
list
(
s
.
shape
),
[
1
,
2
,
1
])
neighbor_coordinate
=
D
[
0
][
0
].
squeeze
()
+
xyz1
for
i
in
range
(
3
):
if
mirror
[
i
]
>
5
:
mirror
[
i
]
-=
10
self
.
assertTensorEqual
(
neighbor_coordinate
,
mirror
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_data.py
View file @
35531421
...
...
@@ -21,7 +21,7 @@ class TestData(unittest.TestCase):
batch_size
)
def
_assertTensorEqual
(
self
,
t1
,
t2
):
self
.
assert
Equal
((
t1
-
t2
).
abs
().
max
().
item
(),
0
)
self
.
assert
Less
((
t1
-
t2
).
abs
().
max
().
item
(),
1e-6
)
def
testSplitBatch
(
self
):
species1
=
torch
.
randint
(
4
,
(
5
,
4
),
dtype
=
torch
.
long
)
...
...
tests/test_energies.py
View file @
35531421
...
...
@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies):
def
setUp
(
self
):
super
(
TestEnergiesASEComputer
,
self
).
setUp
()
self
.
aev_computer
.
neighborlist
=
torchani
.
ase
.
NeighborList
()
def
transform
(
self
,
x
):
"""To reduce the size of test cases for faster test speed"""
...
...
tests/test_forces.py
View file @
35531421
...
...
@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce):
def
setUp
(
self
):
super
(
TestForceASEComputer
,
self
).
setUp
()
self
.
aev_computer
.
neighborlist
=
torchani
.
ase
.
NeighborList
()
def
transform
(
self
,
x
):
"""To reduce the size of test cases for faster test speed"""
...
...
tests/test_ignite.py
View file @
35531421
...
...
@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase):
shift_energy
=
builtins
.
energy_shifter
ds
=
torchani
.
data
.
BatchedANIDataset
(
path
,
builtins
.
consts
.
species_to_tensor
,
batchsize
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
],
device
=
aev_computer
.
EtaR
.
device
)
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
class
Flatten
(
torch
.
nn
.
Module
):
...
...
torchani/aev.py
View file @
35531421
...
...
@@ -2,13 +2,12 @@ from __future__ import division
import
torch
from
.
import
_six
# noqa:F401
import
math
from
.
import
utils
from
torch
import
Tensor
from
typing
import
Tuple
@
torch
.
jit
.
script
def
_
cutoff_cosine
(
distances
,
cutoff
):
#
@torch.jit.script
def
cutoff_cosine
(
distances
,
cutoff
):
# type: (Tensor, float) -> Tensor
return
torch
.
where
(
distances
<=
cutoff
,
...
...
@@ -17,8 +16,8 @@ def _cutoff_cosine(distances, cutoff):
)
@
torch
.
jit
.
script
def
_
radial_
subaev_
terms
(
Rcr
,
EtaR
,
ShfR
,
distances
):
#
@torch.jit.script
def
radial_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
...
...
@@ -33,7 +32,7 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances
=
distances
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
fc
=
_
cutoff_cosine
(
distances
,
Rcr
)
fc
=
cutoff_cosine
(
distances
,
Rcr
)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
...
...
@@ -45,8 +44,8 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
return
ret
.
flatten
(
start_dim
=-
2
)
@
torch
.
jit
.
script
def
_
angular_
subaev_
terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vectors1
,
vectors2
):
#
@torch.jit.script
def
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vectors1
,
vectors2
):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
...
...
@@ -60,22 +59,18 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1
=
vectors1
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
vectors2
=
vectors2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
vectors1
=
vectors1
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
vectors2
=
vectors2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
distances1
=
vectors1
.
norm
(
2
,
dim
=-
5
)
distances2
=
vectors2
.
norm
(
2
,
dim
=-
5
)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles
=
0.95
*
\
torch
.
nn
.
functional
.
cosine_similarity
(
vectors1
,
vectors2
,
dim
=-
5
)
cos_angles
=
0.95
*
torch
.
nn
.
functional
.
cosine_similarity
(
vectors1
,
vectors2
,
dim
=-
5
)
angles
=
torch
.
acos
(
cos_angles
)
fcj1
=
_
cutoff_cosine
(
distances1
,
Rca
)
fcj2
=
_
cutoff_cosine
(
distances2
,
Rca
)
fcj1
=
cutoff_cosine
(
distances1
,
Rca
)
fcj2
=
cutoff_cosine
(
distances2
,
Rca
)
factor1
=
((
1
+
torch
.
cos
(
angles
-
ShfZ
))
/
2
)
**
Zeta
factor2
=
torch
.
exp
(
-
EtaA
*
((
distances1
+
distances2
)
/
2
-
ShfA
)
**
2
)
ret
=
2
*
factor1
*
factor2
*
fcj1
*
fcj2
...
...
@@ -86,168 +81,231 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return
ret
.
flatten
(
start_dim
=-
4
)
@
torch
.
jit
.
script
def
_combinations
(
tensor
,
dim
=
0
):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
dtype
=
torch
.
long
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
@
torch
.
jit
.
script
def
_terms_and_indices
(
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
distances
,
vec
):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms
=
_radial_subaev_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
)
vec
=
_combinations
(
vec
,
-
2
)
angular_terms
=
_angular_subaev_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
*
vec
)
return
radial_terms
,
angular_terms
# @torch.jit.script
def
default_neighborlist
(
species
,
coordinates
,
cutoff
):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
"""Default neighborlist computer"""
vec
=
coordinates
.
unsqueeze
(
2
)
-
coordinates
.
unsqueeze
(
1
)
# vec has hape (conformations, atoms, atoms, 3) storing Rij vectors
distances
=
vec
.
norm
(
2
,
-
1
)
# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask
=
(
species
==
-
1
).
unsqueeze
(
1
)
distances
=
distances
.
masked_fill
(
padding_mask
,
math
.
inf
)
distances
,
indices
=
distances
.
sort
(
-
1
)
min_distances
,
_
=
distances
.
flatten
(
end_dim
=
1
).
min
(
0
)
in_cutoff
=
(
min_distances
<=
cutoff
).
nonzero
().
flatten
()[
1
:]
indices
=
indices
.
index_select
(
-
1
,
in_cutoff
)
# TODO: remove this workaround after gather support broadcasting
atoms
=
coordinates
.
shape
[
1
]
species_
=
species
.
unsqueeze
(
1
).
expand
(
-
1
,
atoms
,
-
1
)
neighbor_species
=
species_
.
gather
(
-
1
,
indices
)
neighbor_distances
=
distances
.
index_select
(
-
1
,
in_cutoff
)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
indices_
=
indices
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
-
1
,
3
)
neighbor_coordinates
=
vec
.
gather
(
-
2
,
indices_
)
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
def
compute_shifts
(
cell
,
pbc
,
cutoff
):
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
@
torch
.
jit
.
script
def
_compute_mask_r
(
species_r
,
num_species
):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
num_species
,
dtype
=
torch
.
long
,
device
=
species_r
.
device
))
return
mask_r
@
torch
.
jit
.
script
def
_compute_mask_a
(
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
species_a1
,
species_a2
=
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask
=
mask_a1
&
mask_a2
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
mask
|
mask_rev
return
mask_a
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
cutoff (float): the cutoff inside which atoms are considered pairs
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (Tensor, Tensor, float) -> Tensor
reciprocal_cell
=
cell
.
inverse
().
t
()
inv_distances
=
reciprocal_cell
.
norm
(
2
,
-
1
)
num_repeats
=
torch
.
ceil
(
cutoff
*
inv_distances
).
to
(
torch
.
long
)
num_repeats
=
torch
.
where
(
pbc
,
num_repeats
,
torch
.
zeros_like
(
num_repeats
))
r1
=
torch
.
arange
(
1
,
num_repeats
[
0
]
+
1
,
device
=
cell
.
device
)
r2
=
torch
.
arange
(
1
,
num_repeats
[
1
]
+
1
,
device
=
cell
.
device
)
r3
=
torch
.
arange
(
1
,
num_repeats
[
2
]
+
1
,
device
=
cell
.
device
)
o
=
torch
.
zeros
(
1
,
dtype
=
torch
.
long
,
device
=
cell
.
device
)
return
torch
.
cat
([
torch
.
cartesian_prod
(
r1
,
r2
,
r3
),
torch
.
cartesian_prod
(
r1
,
r2
,
o
),
torch
.
cartesian_prod
(
r1
,
r2
,
-
r3
),
torch
.
cartesian_prod
(
r1
,
o
,
r3
),
torch
.
cartesian_prod
(
r1
,
o
,
o
),
torch
.
cartesian_prod
(
r1
,
o
,
-
r3
),
torch
.
cartesian_prod
(
r1
,
-
r2
,
r3
),
torch
.
cartesian_prod
(
r1
,
-
r2
,
o
),
torch
.
cartesian_prod
(
r1
,
-
r2
,
-
r3
),
torch
.
cartesian_prod
(
o
,
r2
,
r3
),
torch
.
cartesian_prod
(
o
,
r2
,
o
),
torch
.
cartesian_prod
(
o
,
r2
,
-
r3
),
torch
.
cartesian_prod
(
o
,
o
,
r3
),
])
@
torch
.
jit
.
script
def
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
# @torch.jit.script
def
neighbor_pairs
(
padding_mask
,
coordinates
,
cell
,
shifts
,
cutoff
):
"""Compute pairs of atoms that are neighbors
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
padding_mask (:class:`torch.Tensor`): boolean tensor of shape
(molecules, atoms) for padding mask. 1 == is padding.
coordinates (:class:`torch.Tensor`): tensor of shape
(molecules, atoms, 3) for atom coordinates.
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor]
coordinates
=
coordinates
.
detach
()
cell
=
cell
.
detach
()
num_atoms
=
padding_mask
.
shape
[
1
]
all_atoms
=
torch
.
arange
(
num_atoms
,
device
=
cell
.
device
)
# Step 2: center cell
p1_center
,
p2_center
=
torch
.
combinations
(
all_atoms
).
unbind
(
-
1
)
shifts_center
=
shifts
.
new_zeros
(
p1_center
.
shape
[
0
],
3
)
# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts
=
shifts
.
shape
[
0
]
all_shifts
=
torch
.
arange
(
num_shifts
,
device
=
cell
.
device
)
shift_index
,
p1
,
p2
=
torch
.
cartesian_prod
(
all_shifts
,
all_atoms
,
all_atoms
).
unbind
(
-
1
)
shifts_outide
=
shifts
.
index_select
(
0
,
shift_index
)
# Step 4: combine results for all cells
shifts_all
=
torch
.
cat
([
shifts_center
,
shifts_outide
])
p1_all
=
torch
.
cat
([
p1_center
,
p1
])
p2_all
=
torch
.
cat
([
p2_center
,
p2
])
shift_values
=
torch
.
mm
(
shifts_all
.
to
(
cell
.
dtype
),
cell
)
# step 5, compute distances, and find all pairs within cutoff
distances
=
(
coordinates
.
index_select
(
1
,
p1_all
)
-
coordinates
.
index_select
(
1
,
p2_all
)
+
shift_values
).
norm
(
2
,
-
1
)
padding_mask
=
(
padding_mask
.
index_select
(
1
,
p1_all
))
|
(
padding_mask
.
index_select
(
1
,
p2_all
))
distances
.
masked_fill_
(
padding_mask
,
math
.
inf
)
in_cutoff
=
(
distances
<=
cutoff
).
nonzero
()
molecule_index
,
pair_index
=
in_cutoff
.
unbind
(
1
)
atom_index1
=
p1_all
[
pair_index
]
atom_index2
=
p2_all
[
pair_index
]
shifts
=
shifts_all
.
index_select
(
0
,
pair_index
)
return
molecule_index
,
atom_index1
,
atom_index2
,
shifts
# torch.jit.script
def
triu_index
(
num_species
):
species
=
torch
.
arange
(
num_species
)
species1
,
species2
=
torch
.
combinations
(
species
,
r
=
2
,
with_replacement
=
True
).
unbind
(
-
1
)
pair_index
=
torch
.
arange
(
species1
.
shape
[
0
])
ret
=
torch
.
zeros
(
num_species
,
num_species
,
dtype
=
torch
.
long
)
ret
[
species1
,
species2
]
=
pair_index
ret
[
species2
,
species1
]
=
pair_index
return
ret
# torch.jit.script
def
convert_pair_index
(
index
):
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
"""
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
to
(
radial_terms
.
dtype
)).
sum
(
-
3
)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
rev_indices
=
torch
.
full
((
num_species
,),
-
1
,
dtype
=
present_species
.
dtype
,
device
=
present_species
.
device
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
dtype
=
torch
.
long
,
device
=
radial_terms
.
device
)
angular_aevs
=
[]
zero_angular_subaev
=
torch
.
zeros
(
conformations
,
atoms
,
angular_sublength
,
dtype
=
radial_terms
.
dtype
,
device
=
radial_terms
.
device
)
for
s1
in
range
(
num_species
):
# TODO: make PyTorch support range(start, end) and
# range(start, end, step) and remove the workaround
# below. The inner for loop should be:
# for s2 in range(s1, num_species):
for
s2
in
range
(
num_species
-
s1
):
s2
+=
s1
i1
=
int
(
rev_indices
[
s1
])
i2
=
int
(
rev_indices
[
s2
])
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[:,
:,
:,
i1
,
i2
].
unsqueeze
(
-
1
)
\
.
to
(
radial_terms
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
@
torch
.
jit
.
script
def
_compute_aev
(
num_species
,
angular_sublength
,
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
species
,
species_
,
distances
,
vec
):
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
present_species
=
utils
.
present_species
(
species
)
radial_terms
,
angular_terms
=
_terms_and_indices
(
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
distances
,
vec
)
mask_r
=
_compute_mask_r
(
species_
,
num_species
)
mask_a
=
_compute_mask_a
(
species_
,
present_species
)
radial
,
angular
=
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
return
species
,
fullaev
n
=
(
torch
.
sqrt
(
1.0
+
8.0
*
index
.
to
(
torch
.
float
))
-
1.0
)
/
2.0
n
=
torch
.
floor
(
n
).
to
(
torch
.
long
)
num_elems
=
n
*
(
n
+
1
)
/
2
return
index
-
num_elems
,
n
+
1
# torch.jit.script
def
cumsum_from_zero
(
input_
):
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
cumsum
=
torch
.
cat
([
input_
.
new_tensor
([
0
]),
cumsum
[:
-
1
]])
return
cumsum
# torch.jit.script
def
triple_by_molecule
(
molecule_index
,
atom_index1
,
atom_index2
):
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
Output: indices for all central atoms and it pairs of neighbors. For
example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
(1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
# convert representation from pair to central-other
n
=
molecule_index
.
shape
[
0
]
mi
=
molecule_index
.
repeat
(
2
)
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
# sort and compute unique key
mi_ai1
=
torch
.
stack
([
mi
,
ai1
],
dim
=
1
)
m_ac
,
rev_indices
,
counts
=
torch
.
_unique_dim2_temporary_will_remove_soon
(
mi_ai1
,
dim
=
0
,
sorted
=
True
,
return_inverse
=
True
,
return_counts
=
True
)
uniqued_molecule_index
,
uniqued_central_atom_index
=
m_ac
.
unbind
(
1
)
# do local combinations within unique key, assuming sorted
pair_sizes
=
counts
*
(
counts
-
1
)
//
2
total_size
=
pair_sizes
.
sum
()
molecule_index
=
torch
.
repeat_interleave
(
uniqued_molecule_index
,
pair_sizes
)
central_atom_index
=
torch
.
repeat_interleave
(
uniqued_central_atom_index
,
pair_sizes
)
cumsum
=
cumsum_from_zero
(
pair_sizes
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
sorted_local_pair_index
=
torch
.
arange
(
total_size
,
device
=
molecule_index
.
device
)
-
cumsum
sorted_local_index1
,
sorted_local_index2
=
convert_pair_index
(
sorted_local_pair_index
)
cumsum
=
cumsum_from_zero
(
counts
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
sorted_local_index1
+=
cumsum
sorted_local_index2
+=
cumsum
# unsort result from last part
argsort
=
rev_indices
.
argsort
()
local_index1
=
argsort
[
sorted_local_index1
]
local_index2
=
argsort
[
sorted_local_index2
]
# compute mapping between representation of central-other to pair
sign1
=
torch
.
where
(
local_index1
<
n
,
torch
.
ones_like
(
local_index1
),
-
torch
.
ones_like
(
local_index1
))
sign2
=
torch
.
where
(
local_index2
<
n
,
torch
.
ones_like
(
local_index2
),
-
torch
.
ones_like
(
local_index2
))
pair_index1
=
torch
.
where
(
local_index1
<
n
,
local_index1
,
local_index1
-
n
)
pair_index2
=
torch
.
where
(
local_index2
<
n
,
local_index2
,
local_index2
-
n
)
return
molecule_index
,
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
# torch.jit.script
def
compute_aev
(
species
,
coordinates
,
cell
,
pbc_switch
,
triu_index
,
constants
,
sizes
):
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
=
constants
num_species
,
radial_sublength
,
radial_length
,
angular_sublength
,
angular_length
,
aev_length
=
sizes
num_molecules
=
species
.
shape
[
0
]
num_atoms
=
species
.
shape
[
1
]
num_species_pairs
=
angular_length
//
angular_sublength
cutoff
=
max
(
Rcr
,
Rca
)
shifts
=
compute_shifts
(
cell
,
pbc_switch
,
cutoff
)
molecule_index
,
atom_index1
,
atom_index2
,
shifts
=
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
shifts
,
cutoff
)
species1
=
species
[
molecule_index
,
atom_index1
]
species2
=
species
[
molecule_index
,
atom_index2
]
shift_values
=
torch
.
mm
(
shifts
.
to
(
cell
.
dtype
),
cell
)
vec
=
coordinates
[
molecule_index
,
atom_index1
,
:]
-
coordinates
[
molecule_index
,
atom_index2
,
:]
+
shift_values
distances
=
vec
.
norm
(
2
,
-
1
)
# compute radial aev
radial_terms_
=
radial_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
)
radial_aev
=
radial_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species
,
radial_sublength
)
index1
=
(
molecule_index
*
num_atoms
+
atom_index1
)
*
num_species
+
species2
index2
=
(
molecule_index
*
num_atoms
+
atom_index2
)
*
num_species
+
species1
radial_aev
.
scatter_add_
(
0
,
index1
.
unsqueeze
(
1
).
expand
(
-
1
,
radial_sublength
),
radial_terms_
)
radial_aev
.
scatter_add_
(
0
,
index2
.
unsqueeze
(
1
).
expand
(
-
1
,
radial_sublength
),
radial_terms_
)
radial_aev
=
radial_aev
.
reshape
(
num_molecules
,
num_atoms
,
radial_length
)
# compute angular aev
molecule_index
,
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
=
triple_by_molecule
(
molecule_index
,
atom_index1
,
atom_index2
)
vec1
=
vec
.
index_select
(
0
,
pair_index1
)
*
sign1
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
vec2
=
vec
.
index_select
(
0
,
pair_index2
)
*
sign2
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
species1_
=
torch
.
where
(
sign1
==
1
,
species2
[
pair_index1
],
species1
[
pair_index1
])
species2_
=
torch
.
where
(
sign2
==
1
,
species2
[
pair_index2
],
species1
[
pair_index2
])
angular_terms_
=
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vec1
,
vec2
)
angular_aev
=
angular_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species_pairs
,
angular_sublength
)
index
=
(
molecule_index
*
num_atoms
+
central_atom_index
)
*
num_species_pairs
+
triu_index
[
species1_
,
species2_
]
angular_aev
.
scatter_add_
(
0
,
index
.
unsqueeze
(
1
).
expand
(
-
1
,
angular_sublength
),
angular_terms_
)
angular_aev
=
angular_aev
.
reshape
(
num_molecules
,
num_atoms
,
angular_length
)
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
class
AEVComputer
(
torch
.
nn
.
Module
):
...
...
@@ -271,20 +329,6 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspoindingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
...
...
@@ -293,8 +337,7 @@ class AEVComputer(torch.nn.Module):
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
,
neighborlist_computer
=
default_neighborlist
):
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
super
(
AEVComputer
,
self
).
__init__
()
self
.
Rcr
=
Rcr
self
.
Rca
=
Rca
...
...
@@ -309,43 +352,60 @@ class AEVComputer(torch.nn.Module):
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
num_species
=
num_species
self
.
neighborlist
=
neighborlist_computer
# The length of radial subaev of a single species
self
.
radial_sublength
=
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
# The length of full radial aev
self
.
radial_length
=
self
.
num_species
*
self
.
radial_sublength
# The length of angular subaev of a single species
self
.
angular_sublength
=
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
\
self
.
ShfA
.
numel
()
*
self
.
ShfZ
.
numel
()
self
.
angular_sublength
=
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
self
.
ShfA
.
numel
()
*
self
.
ShfZ
.
numel
()
# The length of full angular aev
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
\
//
2
*
self
.
angular_sublength
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
//
2
*
self
.
angular_sublength
# The length of full aev
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
self
.
sizes
=
self
.
num_species
,
self
.
radial_sublength
,
self
.
radial_length
,
self
.
angular_sublength
,
self
.
angular_length
,
self
.
aev_length
self
.
register_buffer
(
'triu_index'
,
triu_index
(
num_species
))
def
constants
(
self
):
return
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
# @torch.jit.script_method
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
input
):
"""Compute AEVs
Arguments:
species_coordinates (tuple): Two tensors: species and coordinates.
input (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of
conformation
s
shape ``(C, A, 3)``, where ``C`` is the number of
molecule
s
in a chunk, and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc
where species and coordinates are the same as described above, cell
is a tensor of shape (3, 3) of the three vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
and pbc is boolean vector of size 3 storing if pbc is enabled
for that direction.
Returns:
tuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
species
,
coordinates
=
species_coordinates
max_cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
species_
,
distances
,
vec
=
self
.
neighborlist
(
species
,
coordinates
,
max_cutoff
)
return
_compute_aev
(
self
.
num_species
,
self
.
angular_sublength
,
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
,
species
,
species_
,
distances
,
vec
)
if
len
(
input
)
==
2
:
species
,
coordinates
=
input
cell
=
torch
.
eye
(
3
,
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
pbc
=
torch
.
zeros
(
3
,
dtype
=
torch
.
uint8
,
device
=
self
.
EtaR
.
device
)
else
:
assert
len
(
input
)
==
4
species
,
coordinates
,
cell
,
pbc
=
input
return
species
,
compute_aev
(
species
,
coordinates
,
cell
,
pbc
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
)
torchani/ase.py
View file @
35531421
...
...
@@ -6,95 +6,13 @@
"""
from
__future__
import
absolute_import
import
math
import
torch
import
ase.neighborlist
from
.
import
utils
import
ase.calculators.calculator
import
ase.units
import
copy
class
NeighborList
(
torch
.
nn
.
Module
):
"""ASE neighborlist computer
Arguments:
cell: same as in :class:`ase.Atoms`
pbc: same as in :class:`ase.Atoms`
"""
def
__init__
(
self
,
cell
=
None
,
pbc
=
None
):
# wrap `cell` and `pbc` with `ase.Atoms`
super
(
NeighborList
,
self
).
__init__
()
a
=
ase
.
Atoms
(
'He'
,
[[
0
,
0
,
0
]],
cell
=
cell
,
pbc
=
pbc
)
self
.
pbc
=
a
.
get_pbc
()
self
.
cell
=
a
.
get_cell
(
complete
=
True
)
def
forward
(
self
,
species
,
coordinates
,
cutoff
):
conformations
=
species
.
shape
[
0
]
max_atoms
=
species
.
shape
[
1
]
neighbor_species
=
[]
neighbor_distances
=
[]
neighbor_vecs
=
[]
for
i
in
range
(
conformations
):
s
=
species
[
i
].
unsqueeze
(
0
)
c
=
coordinates
[
i
].
unsqueeze
(
0
)
s
,
c
=
utils
.
strip_redundant_padding
(
s
,
c
)
s
=
s
.
squeeze
()
c
=
c
.
squeeze
()
atoms
=
s
.
shape
[
0
]
atoms_object
=
ase
.
Atoms
(
[
'He'
]
*
atoms
,
# chemical symbols are not important here
positions
=
c
.
detach
().
numpy
(),
pbc
=
self
.
pbc
,
cell
=
self
.
cell
)
idx1
,
idx2
,
shift
=
ase
.
neighborlist
.
neighbor_list
(
'ijS'
,
atoms_object
,
cutoff
)
# NB: The absolute distance and distance vectors computed by
# `neighbor_list`can not be used since it does not preserve
# gradient information
idx1
=
torch
.
tensor
(
idx1
,
device
=
coordinates
.
device
,
dtype
=
torch
.
long
)
idx2
=
torch
.
tensor
(
idx2
,
device
=
coordinates
.
device
,
dtype
=
torch
.
long
)
D
=
c
.
index_select
(
0
,
idx2
)
-
c
.
index_select
(
0
,
idx1
)
shift
=
torch
.
tensor
(
shift
,
device
=
coordinates
.
device
,
dtype
=
coordinates
.
dtype
)
cell
=
torch
.
tensor
(
self
.
cell
,
device
=
coordinates
.
device
,
dtype
=
coordinates
.
dtype
)
D
+=
torch
.
mm
(
shift
,
cell
)
d
=
D
.
norm
(
2
,
-
1
)
neighbor_species1
=
[]
neighbor_distances1
=
[]
neighbor_vecs1
=
[]
for
i
in
range
(
atoms
):
this_atom_indices
=
(
idx1
==
i
).
nonzero
().
flatten
()
neighbor_indices
=
idx2
[
this_atom_indices
]
neighbor_species1
.
append
(
s
[
neighbor_indices
])
neighbor_distances1
.
append
(
d
[
this_atom_indices
])
neighbor_vecs1
.
append
(
D
.
index_select
(
0
,
this_atom_indices
))
for
i
in
range
(
max_atoms
-
atoms
):
neighbor_species1
.
append
(
torch
.
full
((
1
,),
-
1
))
neighbor_distances1
.
append
(
torch
.
full
((
1
,),
math
.
inf
))
neighbor_vecs1
.
append
(
torch
.
full
((
1
,
3
),
0
))
neighbor_species1
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_species1
,
padding_value
=-
1
)
neighbor_distances1
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_distances1
,
padding_value
=
math
.
inf
)
neighbor_vecs1
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_vecs1
,
padding_value
=
0
)
neighbor_species
.
append
(
neighbor_species1
)
neighbor_distances
.
append
(
neighbor_distances1
)
neighbor_vecs
.
append
(
neighbor_vecs1
)
neighbor_species
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_species
,
batch_first
=
True
,
padding_value
=-
1
)
neighbor_distances
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_distances
,
batch_first
=
True
,
padding_value
=
math
.
inf
)
neighbor_vecs
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
neighbor_vecs
,
batch_first
=
True
,
padding_value
=
0
)
return
neighbor_species
.
permute
(
0
,
2
,
1
),
\
neighbor_distances
.
permute
(
0
,
2
,
1
),
\
neighbor_vecs
.
permute
(
0
,
2
,
1
,
3
)
import
numpy
class
Calculator
(
ase
.
calculators
.
calculator
.
Calculator
):
...
...
@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
"""
implemented_properties
=
[
'energy'
,
'forces'
]
def
__init__
(
self
,
species
,
aev_computer
,
model
,
energy_shifter
,
dtype
=
torch
.
float64
,
_default_neighborlist
=
False
):
def
__init__
(
self
,
species
,
aev_computer
,
model
,
energy_shifter
,
dtype
=
torch
.
float64
):
super
(
Calculator
,
self
).
__init__
()
self
.
_default_neighborlist
=
_default_neighborlist
self
.
species_to_tensor
=
utils
.
ChemicalSymbolsToInts
(
species
)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
...
...
@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator):
def
calculate
(
self
,
atoms
=
None
,
properties
=
[
'energy'
],
system_changes
=
ase
.
calculators
.
calculator
.
all_changes
):
super
(
Calculator
,
self
).
calculate
(
atoms
,
properties
,
system_changes
)
if
not
self
.
_default_neighborlist
:
self
.
aev_computer
.
neighborlist
.
pbc
=
self
.
atoms
.
get_pbc
()
self
.
aev_computer
.
neighborlist
.
cell
=
\
self
.
atoms
.
get_cell
(
complete
=
True
)
cell
=
torch
.
tensor
(
self
.
atoms
.
get_cell
(
complete
=
True
),
requires_grad
=
True
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
pbc
=
torch
.
tensor
(
self
.
atoms
.
get_pbc
().
astype
(
numpy
.
uint8
),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# print(cell, pbc)
species
=
self
.
species_to_tensor
(
self
.
atoms
.
get_chemical_symbols
())
species
=
species
.
unsqueeze
(
0
)
coordinates
=
torch
.
tensor
(
self
.
atoms
.
get_positions
())
coordinates
=
coordinates
.
unsqueeze
(
0
).
to
(
self
.
device
).
to
(
self
.
dtype
)
\
.
requires_grad_
(
'forces'
in
properties
)
_
,
energy
=
self
.
whole
((
species
,
coordinates
))
_
,
energy
=
self
.
whole
((
species
,
coordinates
,
cell
,
pbc
))
energy
*=
ase
.
units
.
Hartree
self
.
results
[
'energy'
]
=
energy
.
item
()
if
'forces'
in
properties
:
...
...
torchani/utils.py
View file @
35531421
...
...
@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates):
return
torch
.
cat
(
species
),
torch
.
cat
(
coordinates
)
@
torch
.
jit
.
script
#
@torch.jit.script
def
present_species
(
species
):
"""Given a vector of species of atoms, compute the unique species present.
...
...
@@ -75,7 +75,8 @@ def present_species(species):
Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
"""
present_species
,
_
=
species
.
flatten
().
_unique
(
sorted
=
True
)
# present_species, _ = species.flatten()._unique(sorted=True)
present_species
=
species
.
flatten
().
unique
(
sorted
=
True
)
if
present_species
[
0
].
item
()
==
-
1
:
present_species
=
present_species
[
1
:]
return
present_species
...
...
@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates):
Arguments:
species (:class:`torch.Tensor`): Long tensor of shape
``(
conformation
s, atoms)``.
``(
molecule
s, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(
conformation
s, atoms, 3)``.
``(
molecule
s, atoms, 3)``.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
...
...
@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates):
return
species
,
coordinates
def
map2central
(
cell
,
coordinates
,
pbc
):
"""Map atoms outside the unit cell into the cell using PBC.
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: coordinates of atoms mapped back to unit cell.
"""
# Step 1: convert coordinates from standard cartesian coordinate to unit
# cell coordinates
inv_cell
=
torch
.
inverse
(
cell
)
coordinates_cell
=
torch
.
matmul
(
coordinates
,
inv_cell
)
# Step 2: wrap cell coordinates into [0, 1)
coordinates_cell
-=
coordinates_cell
.
floor
()
*
pbc
.
to
(
coordinates_cell
.
dtype
)
# Step 3: convert from cell coordinates back to standard cartesian
# coordinate
return
torch
.
matmul
(
coordinates_cell
,
cell
)
class
EnergyShifter
(
torch
.
nn
.
Module
):
"""Helper class for adding and subtracting self atomic energies
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment