Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
81e6150c
Unverified
Commit
81e6150c
authored
Nov 07, 2019
by
Gao, Xiang
Committed by
GitHub
Nov 07, 2019
Browse files
__constants__ is deprecated by torch.jit (#378)
* __constants__ is deprecated * commit
parent
d32081e9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
7 deletions
+16
-7
torchani/aev.py
torchani/aev.py
+16
-7
No files found.
torchani/aev.py
View file @
81e6150c
...
@@ -2,6 +2,7 @@ import torch
...
@@ -2,6 +2,7 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
import
math
import
math
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
from
torch.jit
import
Final
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
def
cutoff_cosine
(
distances
:
Tensor
,
cutoff
:
float
)
->
Tensor
:
...
@@ -226,9 +227,9 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
...
@@ -226,9 +227,9 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
def
compute_aev
(
species
:
Tensor
,
coordinates
:
Tensor
,
cell
:
Tensor
,
def
compute_aev
(
species
:
Tensor
,
coordinates
:
Tensor
,
cell
:
Tensor
,
shifts
:
Tensor
,
triu_index
:
Tensor
,
shifts
:
Tensor
,
triu_index
:
Tensor
,
constants
:
Tuple
[
float
,
Tensor
,
Tensor
,
float
,
Tensor
,
Tensor
,
Tensor
,
Tensor
],
constants
:
Tuple
[
float
,
Tensor
,
Tensor
,
float
,
Tensor
,
Tensor
,
Tensor
,
Tensor
],
sizes
:
Tuple
[
int
,
int
,
int
,
int
,
int
,
int
])
->
Tensor
:
sizes
:
Tuple
[
int
,
int
,
int
,
int
,
int
])
->
Tensor
:
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
=
constants
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
=
constants
num_species
,
radial_sublength
,
radial_length
,
angular_sublength
,
angular_length
,
aev_length
=
sizes
num_species
,
radial_sublength
,
radial_length
,
angular_sublength
,
angular_length
=
sizes
num_molecules
=
species
.
shape
[
0
]
num_molecules
=
species
.
shape
[
0
]
num_atoms
=
species
.
shape
[
1
]
num_atoms
=
species
.
shape
[
1
]
num_species_pairs
=
angular_length
//
angular_sublength
num_species_pairs
=
angular_length
//
angular_sublength
...
@@ -300,15 +301,24 @@ class AEVComputer(torch.nn.Module):
...
@@ -300,15 +301,24 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper:
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
"""
__constants__
=
[
'Rcr'
,
'Rca'
,
'num_species'
,
'radial_sublength'
,
Rcr
:
Final
[
float
]
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
Rca
:
Final
[
float
]
'aev_length'
]
num_species
:
Final
[
int
]
radial_sublength
:
Final
[
int
]
radial_length
:
Final
[
int
]
angular_sublength
:
Final
[
int
]
angular_length
:
Final
[
int
]
aev_length
:
Final
[
int
]
sizes
:
Final
[
Tuple
[
int
,
int
,
int
,
int
,
int
]]
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
super
(
AEVComputer
,
self
).
__init__
()
super
(
AEVComputer
,
self
).
__init__
()
self
.
Rcr
=
Rcr
self
.
Rcr
=
Rcr
self
.
Rca
=
Rca
self
.
Rca
=
Rca
assert
Rca
<=
Rcr
,
"Current implementation of AEVComputer assumes Rca <= Rcr"
assert
Rca
<=
Rcr
,
"Current implementation of AEVComputer assumes Rca <= Rcr"
self
.
num_species
=
num_species
# convert constant tensors to a ready-to-broadcast shape
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
# shape convension (..., EtaR, ShfR)
self
.
register_buffer
(
'EtaR'
,
EtaR
.
view
(
-
1
,
1
))
self
.
register_buffer
(
'EtaR'
,
EtaR
.
view
(
-
1
,
1
))
...
@@ -319,7 +329,6 @@ class AEVComputer(torch.nn.Module):
...
@@ -319,7 +329,6 @@ class AEVComputer(torch.nn.Module):
self
.
register_buffer
(
'ShfA'
,
ShfA
.
view
(
1
,
1
,
-
1
,
1
))
self
.
register_buffer
(
'ShfA'
,
ShfA
.
view
(
1
,
1
,
-
1
,
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
num_species
=
num_species
# The length of radial subaev of a single species
# The length of radial subaev of a single species
self
.
radial_sublength
=
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
self
.
radial_sublength
=
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
# The length of full radial aev
# The length of full radial aev
...
@@ -330,7 +339,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -330,7 +339,7 @@ class AEVComputer(torch.nn.Module):
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
//
2
*
self
.
angular_sublength
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
//
2
*
self
.
angular_sublength
# The length of full aev
# The length of full aev
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
self
.
sizes
=
self
.
num_species
,
self
.
radial_sublength
,
self
.
radial_length
,
self
.
angular_sublength
,
self
.
angular_length
,
self
.
aev_length
self
.
sizes
=
self
.
num_species
,
self
.
radial_sublength
,
self
.
radial_length
,
self
.
angular_sublength
,
self
.
angular_length
self
.
register_buffer
(
'triu_index'
,
triu_index
(
num_species
).
to
(
device
=
self
.
EtaR
.
device
))
self
.
register_buffer
(
'triu_index'
,
triu_index
(
num_species
).
to
(
device
=
self
.
EtaR
.
device
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment