Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
bacde2e7
Unverified
Commit
bacde2e7
authored
Apr 08, 2019
by
Gao, Xiang
Committed by
GitHub
Apr 08, 2019
Browse files
Remove unnecessary molecule index (#206)
parent
1c0f0e76
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
27 deletions
+22
-27
tests/test_aev.py
tests/test_aev.py
+4
-8
torchani/aev.py
torchani/aev.py
+18
-19
No files found.
tests/test_aev.py
View file @
bacde2e7
...
@@ -176,8 +176,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
...
@@ -176,8 +176,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
for
xyz2
in
xyz2s
:
for
xyz2
in
xyz2s
:
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
to
(
torch
.
double
).
unsqueeze
(
0
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
to
(
torch
.
double
).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
...
@@ -194,8 +193,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
...
@@ -194,8 +193,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2
[
i
]
=
9.9
xyz2
[
i
]
=
9.9
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
...
@@ -215,8 +213,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
...
@@ -215,8 +213,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2
[
j
]
=
new_i
xyz2
[
j
]
=
new_i
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
...
@@ -231,8 +228,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
...
@@ -231,8 +228,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2
=
torch
.
tensor
([
10.0
,
0.1
,
0.1
],
dtype
=
torch
.
double
)
xyz2
=
torch
.
tensor
([
10.0
,
0.1
,
0.1
],
dtype
=
torch
.
double
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
coordinates
=
torch
.
stack
([
xyz1
,
xyz2
]).
unsqueeze
(
0
)
molecule_index
,
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
atom_index1
,
atom_index2
,
_
=
torchani
.
aev
.
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
allshifts
,
1
)
self
.
assertEqual
(
molecule_index
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index1
.
tolist
(),
[
0
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
self
.
assertEqual
(
atom_index2
.
tolist
(),
[
1
])
...
...
torchani/aev.py
View file @
bacde2e7
...
@@ -169,10 +169,11 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
...
@@ -169,10 +169,11 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
distances
.
masked_fill_
(
padding_mask
,
math
.
inf
)
distances
.
masked_fill_
(
padding_mask
,
math
.
inf
)
in_cutoff
=
(
distances
<=
cutoff
).
nonzero
()
in_cutoff
=
(
distances
<=
cutoff
).
nonzero
()
molecule_index
,
pair_index
=
in_cutoff
.
unbind
(
1
)
molecule_index
,
pair_index
=
in_cutoff
.
unbind
(
1
)
molecule_index
*=
num_atoms
atom_index1
=
p1_all
[
pair_index
]
atom_index1
=
p1_all
[
pair_index
]
atom_index2
=
p2_all
[
pair_index
]
atom_index2
=
p2_all
[
pair_index
]
shifts
=
shifts_all
.
index_select
(
0
,
pair_index
)
shifts
=
shifts_all
.
index_select
(
0
,
pair_index
)
return
molecule_index
,
atom_index1
,
atom_index2
,
shifts
return
molecule_index
+
atom_index1
,
molecule_index
+
atom_index2
,
shifts
# torch.jit.script
# torch.jit.script
...
@@ -219,7 +220,7 @@ def cumsum_from_zero(input_):
...
@@ -219,7 +220,7 @@ def cumsum_from_zero(input_):
# torch.jit.script
# torch.jit.script
def
triple_by_molecule
(
molecule_index
,
atom_index1
,
atom_index2
):
def
triple_by_molecule
(
atom_index1
,
atom_index2
):
"""Input: indices for pairs of atoms that are close to each other.
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
(2, 1) exists.
...
@@ -230,24 +231,20 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
...
@@ -230,24 +231,20 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
"""
# convert representation from pair to central-other
# convert representation from pair to central-others
n
=
molecule_index
.
shape
[
0
]
n
=
atom_index1
.
shape
[
0
]
mi
=
molecule_index
.
repeat
(
2
)
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
ai1
=
torch
.
cat
([
atom_index1
,
atom_index2
])
# sort and compute unique key
# sort and compute unique key
mi_ai1
=
torch
.
stack
([
mi
,
ai1
],
dim
=
1
)
uniqued_central_atom_index
,
rev_indices
,
counts
=
torch
.
_unique2_temporary_will_remove_soon
(
ai1
,
sorted
=
True
,
return_inverse
=
True
,
return_counts
=
True
)
m_ac
,
rev_indices
,
counts
=
torch
.
_unique_dim2_temporary_will_remove_soon
(
mi_ai1
,
dim
=
0
,
sorted
=
True
,
return_inverse
=
True
,
return_counts
=
True
)
uniqued_molecule_index
,
uniqued_central_atom_index
=
m_ac
.
unbind
(
1
)
# do local combinations within unique key, assuming sorted
# do local combinations within unique key, assuming sorted
pair_sizes
=
counts
*
(
counts
-
1
)
//
2
pair_sizes
=
counts
*
(
counts
-
1
)
//
2
total_size
=
pair_sizes
.
sum
()
total_size
=
pair_sizes
.
sum
()
molecule_index
=
torch
.
repeat_interleave
(
uniqued_molecule_index
,
pair_sizes
)
central_atom_index
=
torch
.
repeat_interleave
(
uniqued_central_atom_index
,
pair_sizes
)
central_atom_index
=
torch
.
repeat_interleave
(
uniqued_central_atom_index
,
pair_sizes
)
cumsum
=
cumsum_from_zero
(
pair_sizes
)
cumsum
=
cumsum_from_zero
(
pair_sizes
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
sorted_local_pair_index
=
torch
.
arange
(
total_size
,
device
=
molecule_index
.
device
)
-
cumsum
sorted_local_pair_index
=
torch
.
arange
(
total_size
,
device
=
cumsum
.
device
)
-
cumsum
sorted_local_index1
,
sorted_local_index2
=
convert_pair_index
(
sorted_local_pair_index
)
sorted_local_index1
,
sorted_local_index2
=
convert_pair_index
(
sorted_local_pair_index
)
cumsum
=
cumsum_from_zero
(
counts
)
cumsum
=
cumsum_from_zero
(
counts
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
cumsum
=
torch
.
repeat_interleave
(
cumsum
,
pair_sizes
)
...
@@ -264,7 +261,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
...
@@ -264,7 +261,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
sign2
=
torch
.
where
(
local_index2
<
n
,
torch
.
ones_like
(
local_index2
),
-
torch
.
ones_like
(
local_index2
))
sign2
=
torch
.
where
(
local_index2
<
n
,
torch
.
ones_like
(
local_index2
),
-
torch
.
ones_like
(
local_index2
))
pair_index1
=
torch
.
where
(
local_index1
<
n
,
local_index1
,
local_index1
-
n
)
pair_index1
=
torch
.
where
(
local_index1
<
n
,
local_index1
,
local_index1
-
n
)
pair_index2
=
torch
.
where
(
local_index2
<
n
,
local_index2
,
local_index2
-
n
)
pair_index2
=
torch
.
where
(
local_index2
<
n
,
local_index2
,
local_index2
-
n
)
return
molecule_index
,
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
return
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
# torch.jit.script
# torch.jit.script
...
@@ -276,32 +273,34 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
...
@@ -276,32 +273,34 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
num_species_pairs
=
angular_length
//
angular_sublength
num_species_pairs
=
angular_length
//
angular_sublength
cutoff
=
max
(
Rcr
,
Rca
)
cutoff
=
max
(
Rcr
,
Rca
)
molecule_index
,
atom_index1
,
atom_index2
,
shifts
=
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
shifts
,
cutoff
)
atom_index1
,
atom_index2
,
shifts
=
neighbor_pairs
(
species
==
-
1
,
coordinates
,
cell
,
shifts
,
cutoff
)
species1
=
species
[
molecule_index
,
atom_index1
]
species
=
species
.
flatten
()
species2
=
species
[
molecule_index
,
atom_index2
]
coordinates
=
coordinates
.
flatten
(
0
,
1
)
species1
=
species
[
atom_index1
]
species2
=
species
[
atom_index2
]
shift_values
=
torch
.
mm
(
shifts
.
to
(
cell
.
dtype
),
cell
)
shift_values
=
torch
.
mm
(
shifts
.
to
(
cell
.
dtype
),
cell
)
vec
=
coordinates
[
molecule_index
,
atom_index1
,
:]
-
coordinates
[
molecule_index
,
atom_index2
,
:]
+
shift_values
vec
=
coordinates
.
index_select
(
0
,
atom_index1
)
-
coordinates
.
index_select
(
0
,
atom_index2
)
+
shift_values
distances
=
vec
.
norm
(
2
,
-
1
)
distances
=
vec
.
norm
(
2
,
-
1
)
# compute radial aev
# compute radial aev
radial_terms_
=
radial_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
)
radial_terms_
=
radial_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
)
radial_aev
=
radial_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species
,
radial_sublength
)
radial_aev
=
radial_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species
,
radial_sublength
)
index1
=
(
molecule_index
*
num_atoms
+
atom_index1
)
*
num_species
+
species2
index1
=
atom_index1
*
num_species
+
species2
index2
=
(
molecule_index
*
num_atoms
+
atom_index2
)
*
num_species
+
species1
index2
=
atom_index2
*
num_species
+
species1
radial_aev
.
index_add_
(
0
,
index1
,
radial_terms_
)
radial_aev
.
index_add_
(
0
,
index1
,
radial_terms_
)
radial_aev
.
index_add_
(
0
,
index2
,
radial_terms_
)
radial_aev
.
index_add_
(
0
,
index2
,
radial_terms_
)
radial_aev
=
radial_aev
.
reshape
(
num_molecules
,
num_atoms
,
radial_length
)
radial_aev
=
radial_aev
.
reshape
(
num_molecules
,
num_atoms
,
radial_length
)
# compute angular aev
# compute angular aev
molecule_index
,
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
=
triple_by_molecule
(
molecule_index
,
atom_index1
,
atom_index2
)
central_atom_index
,
pair_index1
,
pair_index2
,
sign1
,
sign2
=
triple_by_molecule
(
atom_index1
,
atom_index2
)
vec1
=
vec
.
index_select
(
0
,
pair_index1
)
*
sign1
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
vec1
=
vec
.
index_select
(
0
,
pair_index1
)
*
sign1
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
vec2
=
vec
.
index_select
(
0
,
pair_index2
)
*
sign2
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
vec2
=
vec
.
index_select
(
0
,
pair_index2
)
*
sign2
.
unsqueeze
(
1
).
to
(
vec
.
dtype
)
species1_
=
torch
.
where
(
sign1
==
1
,
species2
[
pair_index1
],
species1
[
pair_index1
])
species1_
=
torch
.
where
(
sign1
==
1
,
species2
[
pair_index1
],
species1
[
pair_index1
])
species2_
=
torch
.
where
(
sign2
==
1
,
species2
[
pair_index2
],
species1
[
pair_index2
])
species2_
=
torch
.
where
(
sign2
==
1
,
species2
[
pair_index2
],
species1
[
pair_index2
])
angular_terms_
=
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vec1
,
vec2
)
angular_terms_
=
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vec1
,
vec2
)
angular_aev
=
angular_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species_pairs
,
angular_sublength
)
angular_aev
=
angular_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species_pairs
,
angular_sublength
)
index
=
(
molecule_index
*
num_atoms
+
central_atom_index
)
*
num_species_pairs
+
triu_index
[
species1_
,
species2_
]
index
=
central_atom_index
*
num_species_pairs
+
triu_index
[
species1_
,
species2_
]
angular_aev
.
index_add_
(
0
,
index
,
angular_terms_
)
angular_aev
.
index_add_
(
0
,
index
,
angular_terms_
)
angular_aev
=
angular_aev
.
reshape
(
num_molecules
,
num_atoms
,
angular_length
)
angular_aev
=
angular_aev
.
reshape
(
num_molecules
,
num_atoms
,
angular_length
)
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment