Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
39137175
Unverified
Commit
39137175
authored
Oct 05, 2018
by
Gao, Xiang
Committed by
GitHub
Oct 05, 2018
Browse files
Misc improvements (#117)
parent
e4fe2a5c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
10 deletions
+5
-10
tests/test_ensemble.py
tests/test_ensemble.py
+1
-1
tests/test_forces.py
tests/test_forces.py
+2
-2
torchani/aev.py
torchani/aev.py
+2
-7
No files found.
tests/test_ensemble.py
View file @
39137175
...
@@ -16,7 +16,7 @@ class TestEnsemble(unittest.TestCase):
...
@@ -16,7 +16,7 @@ class TestEnsemble(unittest.TestCase):
def
_test_molecule
(
self
,
coordinates
,
species
):
def
_test_molecule
(
self
,
coordinates
,
species
):
builtins
=
torchani
.
neurochem
.
Builtins
()
builtins
=
torchani
.
neurochem
.
Builtins
()
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
.
requires_grad
_
(
True
)
aev
=
builtins
.
aev_computer
aev
=
builtins
.
aev_computer
ensemble
=
builtins
.
models
ensemble
=
builtins
.
models
models
=
[
torch
.
nn
.
Sequential
(
aev
,
m
)
for
m
in
ensemble
]
models
=
[
torch
.
nn
.
Sequential
(
aev
,
m
)
for
m
in
ensemble
]
...
...
tests/test_forces.py
View file @
39137175
...
@@ -22,7 +22,7 @@ class TestForce(unittest.TestCase):
...
@@ -22,7 +22,7 @@ class TestForce(unittest.TestCase):
datafile
=
os
.
path
.
join
(
path
,
'test_data/{}'
.
format
(
i
))
datafile
=
os
.
path
.
join
(
path
,
'test_data/{}'
.
format
(
i
))
with
open
(
datafile
,
'rb'
)
as
f
:
with
open
(
datafile
,
'rb'
)
as
f
:
coordinates
,
species
,
_
,
_
,
_
,
forces
=
pickle
.
load
(
f
)
coordinates
,
species
,
_
,
_
,
_
,
forces
=
pickle
.
load
(
f
)
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
.
requires_grad
_
(
True
)
_
,
energies
=
self
.
model
((
species
,
coordinates
))
_
,
energies
=
self
.
model
((
species
,
coordinates
))
derivative
=
torch
.
autograd
.
grad
(
energies
.
sum
(),
derivative
=
torch
.
autograd
.
grad
(
energies
.
sum
(),
coordinates
)[
0
]
coordinates
)[
0
]
...
@@ -36,7 +36,7 @@ class TestForce(unittest.TestCase):
...
@@ -36,7 +36,7 @@ class TestForce(unittest.TestCase):
datafile
=
os
.
path
.
join
(
path
,
'test_data/{}'
.
format
(
i
))
datafile
=
os
.
path
.
join
(
path
,
'test_data/{}'
.
format
(
i
))
with
open
(
datafile
,
'rb'
)
as
f
:
with
open
(
datafile
,
'rb'
)
as
f
:
coordinates
,
species
,
_
,
_
,
_
,
forces
=
pickle
.
load
(
f
)
coordinates
,
species
,
_
,
_
,
_
,
forces
=
pickle
.
load
(
f
)
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
.
requires_grad
_
(
True
)
species_coordinates
.
append
((
species
,
coordinates
))
species_coordinates
.
append
((
species
,
coordinates
))
coordinates_forces
.
append
((
coordinates
,
forces
))
coordinates_forces
.
append
((
coordinates
,
forces
))
species
,
coordinates
=
torchani
.
utils
.
pad_coordinates
(
species
,
coordinates
=
torchani
.
utils
.
pad_coordinates
(
...
...
torchani/aev.py
View file @
39137175
...
@@ -155,11 +155,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -155,11 +155,7 @@ class AEVComputer(torch.nn.Module):
"""Shape (conformations, atoms, atoms) storing Rij distances"""
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask
=
(
species
==
-
1
).
unsqueeze
(
1
)
padding_mask
=
(
species
==
-
1
).
unsqueeze
(
1
)
distances
=
torch
.
where
(
distances
=
distances
.
masked_fill
(
padding_mask
,
math
.
inf
)
padding_mask
,
torch
.
tensor
(
math
.
inf
,
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
),
distances
)
distances
,
indices
=
distances
.
sort
(
-
1
)
distances
,
indices
=
distances
.
sort
(
-
1
)
...
@@ -172,11 +168,10 @@ class AEVComputer(torch.nn.Module):
...
@@ -172,11 +168,10 @@ class AEVComputer(torch.nn.Module):
radial_terms
=
self
.
_radial_subaev_terms
(
distances
)
radial_terms
=
self
.
_radial_subaev_terms
(
distances
)
indices_a
=
indices
.
index_select
(
-
1
,
inRca
)
indices_a
=
indices
.
index_select
(
-
1
,
inRca
)
new_shape
=
list
(
indices_a
.
shape
)
+
[
3
]
# TODO: remove this workaround when gather support broadcasting
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
# https://github.com/pytorch/pytorch/pull/9532
_indices_a
=
indices_a
.
unsqueeze
(
-
1
).
expand
(
*
new_shape
)
_indices_a
=
indices_a
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
-
1
,
3
)
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment