Unverified Commit b16a3234 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

use more specific ops for better performance (#60)

parent 7606fc3b
...@@ -360,8 +360,12 @@ class SortedAEV(AEVComputer): ...@@ -360,8 +360,12 @@ class SortedAEV(AEVComputer):
n = tensor.shape[dim] n = tensor.shape[dim]
r = torch.arange(n).type(torch.long).to(tensor.device) r = torch.arange(n).type(torch.long).to(tensor.device)
grid_x, grid_y = torch.meshgrid([r, r]) grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y[torch.triu(torch.ones(n, n), diagonal=1) == 1] index1 = grid_y.masked_select(
index2 = grid_x[torch.triu(torch.ones(n, n), diagonal=1) == 1] torch.triu(torch.ones(n, n, device=self.EtaR.device),
diagonal=1) == 1)
index2 = grid_x.masked_select(
torch.triu(torch.ones(n, n, device=self.EtaR.device),
diagonal=1) == 1)
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
...@@ -481,9 +485,9 @@ class SortedAEV(AEVComputer): ...@@ -481,9 +485,9 @@ class SortedAEV(AEVComputer):
radial_terms, angular_terms, indices_r, indices_a = \ radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(coordinates) self.terms_and_indices(coordinates)
species_r = species[indices_r] species_r = species.take(indices_r)
mask_r = self.compute_mask_r(species_r) mask_r = self.compute_mask_r(species_r)
species_a = species[indices_a] species_a = species.take(indices_a)
mask_a = self.compute_mask_a(species_a, present_species) mask_a = self.compute_mask_a(species_a, present_species)
radial, angular = self.assemble(radial_terms, angular_terms, radial, angular = self.assemble(radial_terms, angular_terms,
......
...@@ -5,7 +5,7 @@ from .pyanitools import anidataloader ...@@ -5,7 +5,7 @@ from .pyanitools import anidataloader
import torch import torch
import torch.utils.data as data import torch.utils.data as data
import pickle import pickle
import collections import collections.abc
class ANIDataset(Dataset): class ANIDataset(Dataset):
...@@ -96,11 +96,11 @@ def collate(batch): ...@@ -96,11 +96,11 @@ def collate(batch):
no_collate = ['coordinates', 'species'] no_collate = ['coordinates', 'species']
if isinstance(batch[0], torch.Tensor): if isinstance(batch[0], torch.Tensor):
return torch.cat(batch) return torch.cat(batch)
elif isinstance(batch[0], collections.Mapping): elif isinstance(batch[0], collections.abc.Mapping):
return {key: ((lambda x: x) if key in no_collate else collate) return {key: ((lambda x: x) if key in no_collate else collate)
([d[key] for d in batch]) ([d[key] for d in batch])
for key in batch[0]} for key in batch[0]}
elif isinstance(batch[0], collections.Sequence): elif isinstance(batch[0], collections.abc.Sequence):
transposed = zip(*batch) transposed = zip(*batch)
return [collate(samples) for samples in transposed] return [collate(samples) for samples in transposed]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment