Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
8c493a6e
Unverified
Commit
8c493a6e
authored
Aug 01, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 01, 2018
Browse files
use 0 size n dim tensor feature to simplify code (#43)
parent
0e992fe5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
24 deletions
+12
-24
torchani/aev.py
torchani/aev.py
+12
-24
No files found.
torchani/aev.py
View file @
8c493a6e
...
@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer):
...
@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer):
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
# TODO: can we move combinations to ATen?
# TODO: can we move combinations to ATen?
vec
=
self
.
combinations
(
vec
,
-
2
)
vec
=
self
.
combinations
(
vec
,
-
2
)
angular_terms
=
self
.
angular_subaev_terms
(
angular_terms
=
self
.
angular_subaev_terms
(
*
vec
)
*
vec
)
if
vec
is
not
None
else
None
return
radial_terms
,
angular_terms
,
indices_r
,
indices_a
return
radial_terms
,
angular_terms
,
indices_r
,
indices_a
...
@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer):
...
@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer):
grid_x
,
grid_y
=
torch
.
meshgrid
([
r
,
r
])
grid_x
,
grid_y
=
torch
.
meshgrid
([
r
,
r
])
index1
=
grid_y
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index1
=
grid_y
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index2
=
grid_x
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
index2
=
grid_x
[
torch
.
triu
(
torch
.
ones
(
n
,
n
),
diagonal
=
1
)
==
1
]
if
torch
.
numel
(
index1
)
==
0
:
# TODO: pytorch are unable to handle size 0 tensor well.
# Is this an expected behavior?
# See: https://github.com/pytorch/pytorch/issues/5014
return
None
return
tensor
.
index_select
(
dim
,
index1
),
\
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
tensor
.
index_select
(
dim
,
index2
)
...
@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer):
...
@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer):
present species) storing the mask for each pair.
present species) storing the mask for each pair.
"""
"""
species_a
=
self
.
combinations
(
species_a
,
-
1
)
species_a
=
self
.
combinations
(
species_a
,
-
1
)
if
species_a
is
not
None
:
species_a1
,
species_a2
=
species_a
# TODO: can we remove this if pytorch support 0 size tensors?
species_a1
,
species_a2
=
species_a
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
if
species_a
is
not
None
:
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
==
present_species
)
present_species
).
unsqueeze
(
-
1
)
mask
=
mask_a1
*
mask_a2
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
==
present_species
)
mask_a
=
(
mask
+
mask_rev
)
>
0
mask
=
mask_a1
*
mask_a2
return
mask_a
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
(
mask
+
mask_rev
)
>
0
return
mask_a
else
:
return
None
def
assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
def
assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
):
mask_r
,
mask_a
):
...
@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer):
...
@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer):
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
dtype
,
device
=
self
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)),
2
):
range
(
len
(
self
.
species
)),
2
):
# TODO: can we remove this if pytorch support 0 size tensors?
if
s1
in
rev_indices
and
s2
in
rev_indices
:
if
s1
in
rev_indices
and
s2
in
rev_indices
and
mask_a
is
not
None
:
i1
=
rev_indices
[
s1
]
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
i2
=
rev_indices
[
s2
]
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment