Unverified Commit 1c0f0e76 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use index_add_ to replace scatter_add (#204)

parent bc4ab994
...@@ -289,8 +289,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -289,8 +289,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength) radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2 index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2
index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1 index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1
radial_aev.scatter_add_(0, index1.unsqueeze(1).expand(-1, radial_sublength), radial_terms_) radial_aev.index_add_(0, index1, radial_terms_)
radial_aev.scatter_add_(0, index2.unsqueeze(1).expand(-1, radial_sublength), radial_terms_) radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length) radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# compute angular aev # compute angular aev
...@@ -302,7 +302,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -302,7 +302,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2) angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength) angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_] index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_]
angular_aev.scatter_add_(0, index.unsqueeze(1).expand(-1, angular_sublength), angular_terms_) angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length) angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1) return torch.cat([radial_aev, angular_aev], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment