Commit f12a27b4 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Bring new_zeros back (#353)

parent 123e4760
...@@ -138,7 +138,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -138,7 +138,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# Step 2: center cell # Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1) p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = torch.zeros((p1_center.shape[0], 3), dtype=shifts.dtype, device=shifts.device) shifts_center = shifts.new_zeros((p1_center.shape[0], 3))
# Step 3: cells with shifts # Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3) # shape convention (shift index, molecule index, atom index, 3)
...@@ -205,7 +205,7 @@ def convert_pair_index(index): ...@@ -205,7 +205,7 @@ def convert_pair_index(index):
def cumsum_from_zero(input_): def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor # type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0) cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([torch.tensor([0], dtype=input_.dtype, device=input_.device), cumsum[:-1]]) cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]])
return cumsum return cumsum
...@@ -275,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -275,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
# compute radial aev # compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances) radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = torch.zeros((num_molecules * num_atoms * num_species, radial_sublength), dtype=radial_terms_.dtype, device=radial_terms_.device) radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
index1 = atom_index1 * num_species + species2 index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1 index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_) radial_aev.index_add_(0, index1, radial_terms_)
...@@ -298,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -298,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1]) species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2]) species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2) angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = torch.zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength), dtype=angular_terms_.dtype, device=angular_terms_.device) angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength))
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_] index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_) angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length) angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment