"torchvision/vscode:/vscode.git/clone" did not exist on "cf78a29b68ef2e18513bc5156cc775ba4bd6dea6"
Unverified commit d9c0130f, authored by Gao, Xiang and committed by GitHub

Refactor cache_aev and cache_sparse_aev (#232)

parent 2ec2fb6d
@@ -276,13 +276,27 @@ class AEVCacheLoader(Dataset):
         aev_path = os.path.join(self.disk_cache, str(index))
         with open(aev_path, 'rb') as f:
             species_aevs = pickle.load(f)
+        for i, sa in enumerate(species_aevs):
+            species, aevs = self.decode_aev(*sa)
+            species_aevs[i] = (
+                species.to(self.dataset.device),
+                aevs.to(self.dataset.device)
+            )
         return species_aevs, output
 
     def __len__(self):
         return len(self.dataset)
 
+    @staticmethod
+    def decode_aev(encoded_species, encoded_aev):
+        return encoded_species, encoded_aev
+
+    @staticmethod
+    def encode_aev(species, aev):
+        return species, aev
+
 
-class SparseAEVCacheLoader(Dataset):
+class SparseAEVCacheLoader(AEVCacheLoader):
     """Build a factory for AEV.
 
     The computation of AEV is the most time-consuming part of the training.
@@ -294,61 +308,26 @@ class SparseAEVCacheLoader(Dataset):
     Arguments:
         disk_cache (str): Directory storing disk caches.
-        device (:class:`torch.dtype`): device to put tensors when iterating.
     """
 
-    def __init__(self, disk_cache=None, device=torch.device('cpu')):
-        super(SparseAEVCacheLoader, self).__init__()
-        self.disk_cache = disk_cache
-        self.device = device
-        # load dataset from disk cache
-        dataset_path = os.path.join(disk_cache, 'dataset')
-        with open(dataset_path, 'rb') as f:
-            self.dataset = pickle.load(f)
-
-    def __getitem__(self, index):
-        _, output = self.dataset.batches[index]
-        aev_path = os.path.join(self.disk_cache, str(index))
-        with open(aev_path, 'rb') as f:
-            species_aevs = pickle.load(f)
-        batch_X = []
-        for species_, aev_ in species_aevs:
-            species_np = np.array(species_.todense())
-            species = torch.from_numpy(species_np).to(self.device)
-            aevs_np = np.stack([np.array(i.todense()) for i in aev_], axis=0)
-            aevs = torch.from_numpy(aevs_np).to(self.device)
-            batch_X.append((species, aevs))
-        return batch_X, output
-
-    def __len__(self):
-        return len(self.dataset)
+    @staticmethod
+    def decode_aev(encoded_species, encoded_aev):
+        species = torch.from_numpy(encoded_species.todense())
+        aevs_np = np.stack([np.array(i.todense()) for i in encoded_aev], axis=0)
+        aevs = torch.from_numpy(aevs_np)
+        return species, aevs
+
+    @staticmethod
+    def encode_aev(species, aev):
+        encoded_species = bsr_matrix(species.cpu().numpy())
+        encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]
+        return encoded_species, encoded_aev
 
 
 builtin = neurochem.Builtins()
 
 
-def cache_aev(output, dataset_path, batchsize, device=default_device,
-              constfile=builtin.const_file, subtract_sae=False,
-              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
-    # if output directory does not exist, then create it
-    if not os.path.exists(output):
-        os.makedirs(output)
-    device = torch.device(device)
-    consts = neurochem.Constants(constfile)
-    aev_computer = aev.AEVComputer(**consts).to(device)
-    if subtract_sae:
-        energy_shifter = neurochem.load_sae(sae_file)
-        transform = (energy_shifter.subtract_from_dataset,)
-    else:
-        transform = ()
-    dataset = BatchedANIDataset(
-        dataset_path, consts.species_to_tensor, batchsize,
-        device=device, transform=transform, **kwargs
-    )
+def create_aev_cache(dataset, aev_computer, output, enable_tqdm=True, encoder=lambda x: x):
     # dump out the dataset
     filename = os.path.join(output, 'dataset')
     with open(filename, 'wb') as f:
@@ -361,15 +340,14 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
         indices = range(len(dataset))
     for i in indices:
         input_, _ = dataset[i]
-        aevs = [aev_computer(j) for j in input_]
+        aevs = [encoder(*aev_computer(j)) for j in input_]
         filename = os.path.join(output, '{}'.format(i))
         with open(filename, 'wb') as f:
             pickle.dump(aevs, f)
 
 
-def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
-                     constfile=builtin.const_file, subtract_sae=False,
-                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+def _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
     # if output directory does not exist, then create it
     if not os.path.exists(output):
         os.makedirs(output)
@@ -389,27 +367,23 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
         device=device, transform=transform, **kwargs
     )
-    # dump out the dataset
-    filename = os.path.join(output, 'dataset')
-    with open(filename, 'wb') as f:
-        pickle.dump(dataset, f)
-    if enable_tqdm:
-        import tqdm
-        indices = tqdm.trange(len(dataset))
-    else:
-        indices = range(len(dataset))
-    for i in indices:
-        input_, _ = dataset[i]
-        aevs = []
-        for j in input_:
-            species_, aev_ = aev_computer(j)
-            species_ = bsr_matrix(species_.cpu().numpy())
-            aev_ = [bsr_matrix(i.cpu().numpy()) for i in aev_]
-            aevs.append((species_, aev_))
-        filename = os.path.join(output, '{}'.format(i))
-        with open(filename, 'wb') as f:
-            pickle.dump(aevs, f)
+    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)
+
+
+def cache_aev(output, dataset_path, batchsize, device=default_device,
+              constfile=builtin.const_file, subtract_sae=False,
+              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+    _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm, AEVCacheLoader.encode_aev,
+               **kwargs)
+
+
+def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
+                     constfile=builtin.const_file, subtract_sae=False,
+                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+    _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm,
+               SparseAEVCacheLoader.encode_aev, **kwargs)
 
 
 __all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
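
A minimal usage sketch of the refactored interface, not part of this commit: it assumes these helpers are importable from torchani.data, that an ANI-style HDF5 file exists at the hypothetical path data/ani_gdb_s01.h5, and that the AEVCacheLoader constructor inherited by SparseAEVCacheLoader takes only the cache directory.

    from torchani.data import cache_sparse_aev, SparseAEVCacheLoader

    # Write one pickle per batch into ./aev_cache; cache_sparse_aev now just
    # forwards to _cache_aev with SparseAEVCacheLoader.encode_aev as the encoder,
    # so species and AEVs are stored as scipy bsr_matrix objects.
    cache_sparse_aev('./aev_cache', 'data/ani_gdb_s01.h5', 256)

    # Reload the cache; the shared AEVCacheLoader.__getitem__ calls
    # SparseAEVCacheLoader.decode_aev to rebuild dense torch tensors and then
    # moves them to the cached dataset's device.
    loader = SparseAEVCacheLoader('./aev_cache')
    for i in range(len(loader)):
        batch_inputs, batch_outputs = loader[i]
        for species, aevs in batch_inputs:
            pass  # feed the precomputed AEVs to the network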