Unverified Commit b16a3234 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

use more specific ops for better performance (#60)

parent 7606fc3b
......@@ -360,8 +360,12 @@ class SortedAEV(AEVComputer):
n = tensor.shape[dim]
r = torch.arange(n).type(torch.long).to(tensor.device)
grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y[torch.triu(torch.ones(n, n), diagonal=1) == 1]
index2 = grid_x[torch.triu(torch.ones(n, n), diagonal=1) == 1]
index1 = grid_y.masked_select(
torch.triu(torch.ones(n, n, device=self.EtaR.device),
diagonal=1) == 1)
index2 = grid_x.masked_select(
torch.triu(torch.ones(n, n, device=self.EtaR.device),
diagonal=1) == 1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
......@@ -481,9 +485,9 @@ class SortedAEV(AEVComputer):
radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(coordinates)
species_r = species[indices_r]
species_r = species.take(indices_r)
mask_r = self.compute_mask_r(species_r)
species_a = species[indices_a]
species_a = species.take(indices_a)
mask_a = self.compute_mask_a(species_a, present_species)
radial, angular = self.assemble(radial_terms, angular_terms,
......
......@@ -5,7 +5,7 @@ from .pyanitools import anidataloader
import torch
import torch.utils.data as data
import pickle
import collections
import collections.abc
class ANIDataset(Dataset):
......@@ -96,11 +96,11 @@ def collate(batch):
no_collate = ['coordinates', 'species']
if isinstance(batch[0], torch.Tensor):
return torch.cat(batch)
elif isinstance(batch[0], collections.Mapping):
elif isinstance(batch[0], collections.abc.Mapping):
return {key: ((lambda x: x) if key in no_collate else collate)
([d[key] for d in batch])
for key in batch[0]}
elif isinstance(batch[0], collections.Sequence):
elif isinstance(batch[0], collections.abc.Sequence):
transposed = zip(*batch)
return [collate(samples) for samples in transposed]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment