Unverified Commit d9c0130f authored by Gao, Xiang, committed by GitHub

Refactor cache_aev and cache_sparse_aev (#232)

parent 2ec2fb6d
@@ -276,13 +276,27 @@ class AEVCacheLoader(Dataset):
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
        for i, sa in enumerate(species_aevs):
            species, aevs = self.decode_aev(*sa)
            species_aevs[i] = (
                species.to(self.dataset.device),
                aevs.to(self.dataset.device)
            )
        return species_aevs, output

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        return encoded_species, encoded_aev

class SparseAEVCacheLoader(Dataset):

    @staticmethod
    def encode_aev(species, aev):
        return species, aev


class SparseAEVCacheLoader(AEVCacheLoader):
"""Build a factory for AEV.
The computation of AEV is the most time-consuming part of the training.
......@@ -294,61 +308,26 @@ class SparseAEVCacheLoader(Dataset):
Arguments:
disk_cache (str): Directory storing disk caches.
        device (:class:`torch.device`): device to put tensors when iterating.
"""
def __init__(self, disk_cache=None, device=torch.device('cpu')):
super(SparseAEVCacheLoader, self).__init__()
self.disk_cache = disk_cache
self.device = device
# load dataset from disk cache
dataset_path = os.path.join(disk_cache, 'dataset')
with open(dataset_path, 'rb') as f:
self.dataset = pickle.load(f)
@staticmethod
def decode_aev(encoded_species, encoded_aev):
species = torch.from_numpy(encoded_species.todense())
aevs_np = np.stack([np.array(i.todense()) for i in encoded_aev], axis=0)
aevs = torch.from_numpy(aevs_np)
return species, aevs
    def __getitem__(self, index):
        _, output = self.dataset.batches[index]
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
        batch_X = []
        for species_, aev_ in species_aevs:
            species_np = np.array(species_.todense())
            species = torch.from_numpy(species_np).to(self.device)
            aevs_np = np.stack([np.array(i.todense()) for i in aev_], axis=0)
            aevs = torch.from_numpy(aevs_np).to(self.device)
            batch_X.append((species, aevs))
        return batch_X, output

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def encode_aev(species, aev):
        encoded_species = bsr_matrix(species.cpu().numpy())
        encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]
        return encoded_species, encoded_aev
builtin = neurochem.Builtins()
def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=builtin.const_file, subtract_sae=False,
              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    # if output directory does not exist, then create it
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = BatchedANIDataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )
def create_aev_cache(dataset, aev_computer, output, enable_tqdm=True, encoder=lambda *x: x):
    # dump out the dataset
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
@@ -361,15 +340,14 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
        indices = range(len(dataset))
    for i in indices:
        input_, _ = dataset[i]
        aevs = [aev_computer(j) for j in input_]
        aevs = [encoder(*aev_computer(j)) for j in input_]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=builtin.const_file, subtract_sae=False,
                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
def _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
    # if output directory does not exist, then create it
    if not os.path.exists(output):
        os.makedirs(output)
@@ -389,27 +367,23 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
        device=device, transform=transform, **kwargs
    )
    # dump out the dataset
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)
    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)
    if enable_tqdm:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))
    for i in indices:
        input_, _ = dataset[i]
        aevs = []
        for j in input_:
            species_, aev_ = aev_computer(j)
            species_ = bsr_matrix(species_.cpu().numpy())
            aev_ = [bsr_matrix(i.cpu().numpy()) for i in aev_]
            aevs.append((species_, aev_))
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=builtin.const_file, subtract_sae=False,
              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, AEVCacheLoader.encode_aev,
               **kwargs)


def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=builtin.const_file, subtract_sae=False,
                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm,
               SparseAEVCacheLoader.encode_aev, **kwargs)
__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
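
For context, a minimal usage sketch of how the refactored caching API might be called end to end. This is not part of the commit: the module path (torchani.data), the file names, and the batch size are assumptions, and the loader is assumed to take only the cache directory now that SparseAEVCacheLoader inherits its constructor from AEVCacheLoader.

# Hedged usage sketch; names marked "hypothetical" are not from this commit.
from torchani.data import cache_sparse_aev, SparseAEVCacheLoader  # assumed module path

cache_dir = 'aev_cache'           # hypothetical output directory for the disk cache
dataset_file = 'training_set.h5'  # hypothetical ANI dataset file

# One-off step: compute the AEVs for every batch and pickle them to disk,
# sparse-encoded via SparseAEVCacheLoader.encode_aev.
cache_sparse_aev(cache_dir, dataset_file, 256, subtract_sae=True)

# Training time: load cached AEVs instead of recomputing them every epoch.
loader = SparseAEVCacheLoader(cache_dir)  # assumes __init__ takes the disk_cache dir
for i in range(len(loader)):
    batch_x, batch_y = loader[i]          # batch_x is a list of (species, aevs) pairs
    for species, aevs in batch_x:
        pass  # feed the precomputed AEVs to the network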