Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
8c493a6e
Unverified
Commit
8c493a6e
authored
Aug 01, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 01, 2018
Browse files
use 0 size n dim tensor feature to simplify code (#43)
parent
0e992fe5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
24 deletions
+12
-24
torchani/aev.py
torchani/aev.py
+12
-24
No files found.
torchani/aev.py
View file @
8c493a6e
...
@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer):
...
@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer):
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
# TODO: can we move combinations to ATen?
# TODO: can we move combinations to ATen?
vec
=
self
.
combinations
(
vec
,
-
2
)
vec
=
self
.
combinations
(
vec
,
-
2
)
angular_terms
=
self
.
angular_subaev_terms
(
angular_terms
=
self
.
angular_subaev_terms
(
*
vec
)
*
vec
)
if
vec
is
not
None
else
None
return
radial_terms
,
angular_terms
,
indices_r
,
indices_a
return
radial_terms
,
angular_terms
,
indices_r
,
indices_a
...
@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer):
...
@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer):
grid_x
,
grid_y
=
torch
.
meshgrid
([
r
,
r
])
grid_x
,
grid_y
=
torch
.
meshgrid
([
r
,
r
])
index1
=
grid_y
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index1
=
grid_y
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index2
=
grid_x
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index2
=
grid_x
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
if
torch
.
numel
(
index1
)
==
0
:
# TODO: pytorch are unable to handle size 0 tensor well.
# Is this an expected behavior?
# See: https://github.com/pytorch/pytorch/issues/5014
return
None
return
tensor
.
index_select
(
dim
,
index1
),
\
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
tensor
.
index_select
(
dim
,
index2
)
...
@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer):
...
@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer):
present species) storing the mask for each pair.
present species) storing the mask for each pair.
"""
"""
species_a
=
self
.
combinations
(
species_a
,
-
1
)
species_a
=
self
.
combinations
(
species_a
,
-
1
)
if
species_a
is
not
None
:
species_a1
,
species_a2
=
species_a
# TODO: can we remove this if pytorch support 0 size tensors?
species_a1
,
species_a2
=
species_a
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
if
species_a
is
not
None
:
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
==
present_species
)
present_species
).
unsqueeze
(
-
1
)
mask
=
mask_a1
*
mask_a2
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
==
present_species
)
mask_a
=
(
mask
+
mask_rev
)
>
0
mask
=
mask_a1
*
mask_a2
return
mask_a
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
(
mask
+
mask_rev
)
>
0
return
mask_a
else
:
return
None
def
assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
def
assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
):
mask_r
,
mask_a
):
...
@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer):
...
@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer):
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
dtype
,
device
=
self
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)),
2
):
range
(
len
(
self
.
species
)),
2
):
# TODO: can we remove this if pytorch support 0 size tensors?
if
s1
in
rev_indices
and
s2
in
rev_indices
:
if
s1
in
rev_indices
and
s2
in
rev_indices
and
mask_a
is
not
None
:
i1
=
rev_indices
[
s1
]
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
i2
=
rev_indices
[
s2
]
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment