Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
1c0f0e76
Unverified
Commit
1c0f0e76
authored
Apr 08, 2019
by
Gao, Xiang
Committed by
GitHub
Apr 08, 2019
Browse files
Use index_add_ to replace scatter_add (#204)
parent
bc4ab994
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
torchani/aev.py
torchani/aev.py
+3
-3
No files found.
torchani/aev.py
View file @
1c0f0e76
...
@@ -289,8 +289,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
...
@@ -289,8 +289,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
radial_aev
=
radial_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species
,
radial_sublength
)
radial_aev
=
radial_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species
,
radial_sublength
)
index1
=
(
molecule_index
*
num_atoms
+
atom_index1
)
*
num_species
+
species2
index1
=
(
molecule_index
*
num_atoms
+
atom_index1
)
*
num_species
+
species2
index2
=
(
molecule_index
*
num_atoms
+
atom_index2
)
*
num_species
+
species1
index2
=
(
molecule_index
*
num_atoms
+
atom_index2
)
*
num_species
+
species1
radial_aev
.
scatter
_add_
(
0
,
index1
.
unsqueeze
(
1
).
expand
(
-
1
,
radial_sublength
)
,
radial_terms_
)
radial_aev
.
index
_add_
(
0
,
index1
,
radial_terms_
)
radial_aev
.
scatter
_add_
(
0
,
index2
.
unsqueeze
(
1
).
expand
(
-
1
,
radial_sublength
)
,
radial_terms_
)
radial_aev
.
index
_add_
(
0
,
index2
,
radial_terms_
)
radial_aev
=
radial_aev
.
reshape
(
num_molecules
,
num_atoms
,
radial_length
)
radial_aev
=
radial_aev
.
reshape
(
num_molecules
,
num_atoms
,
radial_length
)
# compute angular aev
# compute angular aev
...
@@ -302,7 +302,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
...
@@ -302,7 +302,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
angular_terms_
=
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vec1
,
vec2
)
angular_terms_
=
angular_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
vec1
,
vec2
)
angular_aev
=
angular_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species_pairs
,
angular_sublength
)
angular_aev
=
angular_terms_
.
new_zeros
(
num_molecules
*
num_atoms
*
num_species_pairs
,
angular_sublength
)
index
=
(
molecule_index
*
num_atoms
+
central_atom_index
)
*
num_species_pairs
+
triu_index
[
species1_
,
species2_
]
index
=
(
molecule_index
*
num_atoms
+
central_atom_index
)
*
num_species_pairs
+
triu_index
[
species1_
,
species2_
]
angular_aev
.
scatter
_add_
(
0
,
index
.
unsqueeze
(
1
).
expand
(
-
1
,
angular_sublength
)
,
angular_terms_
)
angular_aev
.
index
_add_
(
0
,
index
,
angular_terms_
)
angular_aev
=
angular_aev
.
reshape
(
num_molecules
,
num_atoms
,
angular_length
)
angular_aev
=
angular_aev
.
reshape
(
num_molecules
,
num_atoms
,
angular_length
)
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment