Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
98bb1237
Unverified
Commit
98bb1237
authored
Nov 07, 2019
by
Gao, Xiang
Committed by
GitHub
Nov 07, 2019
Browse files
Use modern type annotations for aev.py (#372)
* Use modern type annotations for aev.py * commit
parent
7499c8d4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
20 deletions
+17
-20
torchani/aev.py
torchani/aev.py
+17
-20
No files found.
torchani/aev.py
View file @
98bb1237
import
torch
from
torch
import
Tensor
import
math
from
typing
import
Tuple
,
Optional
def
cutoff_cosine
(
distances
,
cutoff
):
# type: (torch.Tensor, float) -> torch.Tensor
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
# assuming all elements in distances are smaller than cutoff
return
0.5
*
torch
.
cos
(
distances
*
(
math
.
pi
/
cutoff
))
+
0.5
def
radial_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def
radial_terms
(
Rcr
:
float
,
EtaR
:
Tensor
,
ShfR
:
Tensor
,
distances
:
Tensor
)
->
Tensor
:
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
...
...
@@ -36,8 +35,8 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return
ret
.
flatten
(
start_dim
=-
2
)
def
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vectors1
,
vectors2
):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.
Tensor) ->
torch.
Tensor
def
angular_terms
(
Rca
:
float
,
ShfZ
:
Tensor
,
EtaA
:
Tensor
,
Zeta
:
Tensor
,
ShfA
:
Tensor
,
vectors1
:
Tensor
,
vectors2
:
Tensor
)
->
Tensor
:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
...
...
@@ -72,8 +71,7 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return
ret
.
flatten
(
start_dim
=-
4
)
def
compute_shifts
(
cell
,
pbc
,
cutoff
):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
def
compute_shifts
(
cell
:
Tensor
,
pbc
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
...
...
@@ -115,8 +113,8 @@ def compute_shifts(cell, pbc, cutoff):
])
def
neighbor_pairs
(
padding_mask
,
coordinates
,
cell
,
shifts
,
cutoff
):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
float) -> Tuple[
torch.
Tensor,
torch.
Tensor,
torch.
Tensor]
def
neighbor_pairs
(
padding_mask
:
Tensor
,
coordinates
:
Tensor
,
cell
:
Tensor
,
shifts
:
Tensor
,
cutoff
:
float
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]
:
"""Compute pairs of atoms that are neighbors
Arguments:
...
...
@@ -164,8 +162,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return
molecule_index
+
atom_index1
,
molecule_index
+
atom_index2
,
shifts
def
triu_index
(
num_species
):
# type: (int) -> torch.Tensor
def
triu_index
(
num_species
:
int
)
->
Tensor
:
species1
,
species2
=
torch
.
triu_indices
(
num_species
,
num_species
).
unbind
(
0
)
pair_index
=
torch
.
arange
(
species1
.
shape
[
0
],
dtype
=
torch
.
long
)
ret
=
torch
.
zeros
(
num_species
,
num_species
,
dtype
=
torch
.
long
)
...
...
@@ -174,15 +171,13 @@ def triu_index(num_species):
return
ret
def
cumsum_from_zero
(
input_
):
# type: (torch.Tensor) -> torch.Tensor
def
cumsum_from_zero
(
input_
:
Tensor
)
->
Tensor
:
cumsum
=
torch
.
cumsum
(
input_
,
dim
=
0
)
cumsum
=
torch
.
cat
([
input_
.
new_zeros
(
1
),
cumsum
[:
-
1
]])
return
cumsum
def
triple_by_molecule
(
atom_index1
,
atom_index2
):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
def
triple_by_molecule
(
atom_index1
:
Tensor
,
atom_index2
:
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
]:
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
...
...
@@ -228,8 +223,10 @@ def triple_by_molecule(atom_index1, atom_index2):
return
central_atom_index
,
local_index1
%
n
,
local_index2
%
n
,
sign1
,
sign2
def
compute_aev
(
species
,
coordinates
,
cell
,
shifts
,
triu_index
,
constants
,
sizes
):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) > torch.Tensor
def
compute_aev
(
species
:
Tensor
,
coordinates
:
Tensor
,
cell
:
Tensor
,
shifts
:
Tensor
,
triu_index
:
Tensor
,
constants
:
Tuple
[
float
,
Tensor
,
Tensor
,
float
,
Tensor
,
Tensor
,
Tensor
,
Tensor
],
sizes
:
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
])
->
Tensor
:
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
=
constants
num_species
,
radial_sublength
,
radial_length
,
angular_sublength
,
angular_length
,
aev_length
=
sizes
num_molecules
=
species
.
shape
[
0
]
...
...
@@ -349,8 +346,8 @@ class AEVComputer(torch.nn.Module):
def
constants
(
self
):
return
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
def
forward
(
self
,
input_
,
cell
=
None
,
pbc
=
None
):
# type: (Tuple[torch.Tensor, torch.Tensor],
Optional[
torch.
Tensor]
, Optional[torch.Tensor]
) -> Tuple[
torch.
Tensor,
torch.
Tensor]
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
pbc
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Tensor
]
:
"""Compute AEVs
Arguments:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment