Unverified commit b3744935 authored by Gao, Xiang, committed by GitHub

working ignite and dataloader with simple examples (#21)

parent ea718be0
 from torch.utils.data import Dataset, DataLoader
 from os.path import join, isfile, isdir
-from os import listdir
+import os
 from .pyanitools import anidataloader
-from .env import default_dtype
+from .env import default_dtype, default_device
 import torch
+import torch.utils.data as data
+import pickle
 class ANIDataset(Dataset):
 
     def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
-                 transform=(), dtype=default_dtype):
+                 transform=(), dtype=default_dtype, device=default_device):
         super(ANIDataset, self).__init__()
         self.path = path
         self.chunks_size = chunk_size
         self.shuffle = shuffle
         self.properties = properties
         self.dtype = dtype
+        self.device = device
 
         # get name of files storing data
         files = []
         if isdir(path):
-            for f in listdir(path):
+            for f in os.listdir(path):
                 f = join(path, f)
                 if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                     files.append(f)
@@ -35,16 +38,17 @@ class ANIDataset(Dataset):
             for m in anidataloader(f):
                 full = {
                     'coordinates': torch.from_numpy(m['coordinates'])
-                    .type(dtype)
+                    .type(dtype).to(device)
                 }
                 conformations = full['coordinates'].shape[0]
                 for i in properties:
-                    full[i] = torch.from_numpy(m[i]).type(dtype)
+                    full[i] = torch.from_numpy(m[i]).type(dtype).to(device)
                 species = m['species']
                 if shuffle:
-                    indices = torch.randperm(conformations)
+                    indices = torch.randperm(conformations, device=device)
                 else:
-                    indices = torch.arange(conformations, dtype=torch.int64)
+                    indices = torch.arange(conformations, dtype=torch.int64,
+                                            device=device)
                 num_chunks = (conformations + chunk_size - 1) // chunk_size
                 for i in range(num_chunks):
                     chunk_start = i * chunk_size
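The chunk count above is a ceiling division, so a trailing partial chunk is kept rather than dropped. A quick check outside the diff (plain Python, example numbers only):

conformations, chunk_size = 1000, 256                       # illustrative values, not from the dataset
num_chunks = (conformations + chunk_size - 1) // chunk_size
assert num_chunks == 4                                      # 3 full chunks of 256 plus one of 232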
@@ -66,6 +70,25 @@ class ANIDataset(Dataset):
         return len(self.chunks)
 
 
+def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
+    """Generate a 80-10-10 split of the dataset, and checkpoint
+    the resulting dataset"""
+    if not os.path.isfile(checkpoint):
+        full_dataset = ANIDataset(dataset_path, chunk_size, *args, **kwargs)
+        training_size = int(len(full_dataset) * 0.8)
+        validation_size = int(len(full_dataset) * 0.1)
+        testing_size = len(full_dataset) - training_size - validation_size
+        lengths = [training_size, validation_size, testing_size]
+        subsets = data.random_split(full_dataset, lengths)
+        with open(checkpoint, 'wb') as f:
+            pickle.dump(subsets, f)
+    # load dataset from checkpoint file
+    with open(checkpoint, 'rb') as f:
+        training, validation, testing = pickle.load(f)
+    return training, validation, testing
+def _collate(batch):
+    input_keys = ['coordinates', 'species']
+    inputs = [{k: i[k] for k in input_keys} for i in batch]
......
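For reference, a minimal usage sketch of the new split-and-checkpoint helper. This is not part of the commit: it assumes the file above is importable as torchani.data and that _collate (truncated above) is the collate function intended for DataLoader; the checkpoint and dataset paths are placeholders.

# Illustrative sketch only; 'torchani.data' and the paths below are assumptions.
from torch.utils.data import DataLoader
from torchani import data as anidata

training, validation, testing = anidata.load_or_create(
    'split.pkl',               # checkpoint file; created on first call, reloaded afterwards
    '/path/to/hdf5/files',     # directory of .h5/.hdf5 files (or a single file)
    chunk_size=256)

loader = DataLoader(training, batch_size=4, collate_fn=anidata._collate)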
@@ -30,7 +30,7 @@ class DictMetric(Metric):
         self.metric.update((y_pred[self.key], y[self.key]))
 
     def compute(self):
-        self.metric.compute()
+        return self.metric.compute()
 
 
 energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
......
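The one-line change above matters because ignite reads a metric's value from whatever compute() returns (for example when storing it into engine.state.metrics after evaluation); without the return, the wrapper reported None. A small standalone check of that contract using a stock ignite metric rather than DictMetric itself:

# Not part of the commit: shows why compute() must return the value it computes.
import torch
from ignite.metrics import RootMeanSquaredError

rmse = RootMeanSquaredError()
rmse.update((torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.5])))
print(rmse.compute())   # ~0.5; DictMetric.compute must forward this value, not swallow it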
@@ -31,7 +31,8 @@ class CustomModel(ANIModel):
                 raise ValueError(
                     '''output length of each atomic neural network must
                     match''')
-            setattr(self, 'model_' + i, model_X)
         super(CustomModel, self).__init__(aev_computer, suffixes, reducer,
                                           output_length, models, derivative,
                                           derivative_graph, benchmark)
+        for i in per_species:
+            setattr(self, 'model_' + i, per_species[i])
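The likely reason for moving the setattr loop after the super().__init__ call is that torch.nn.Module refuses submodule assignment before its own __init__ has run. A minimal standalone illustration, unrelated to CustomModel's actual constructor arguments:

# Sketch only: demonstrates the PyTorch rule, not CustomModel's real signature.
import torch

class Fixed(torch.nn.Module):
    def __init__(self, per_species):
        super(Fixed, self).__init__()                        # Module bookkeeping must exist first
        for i in per_species:
            setattr(self, 'model_' + i, per_species[i])      # now registered as submodules

# Assigning a submodule before super().__init__() instead raises
# "AttributeError: cannot assign module before Module.__init__() call".
m = Fixed({'H': torch.nn.Linear(8, 1), 'C': torch.nn.Linear(8, 1)})
print([name for name, _ in m.named_children()])              # ['model_H', 'model_C']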