Unverified commit b3744935 authored by Gao, Xiang, committed by GitHub

working ignite and dataloader with simple examples (#21)

parent ea718be0
 from torch.utils.data import Dataset, DataLoader
 from os.path import join, isfile, isdir
-from os import listdir
+import os
 from .pyanitools import anidataloader
-from .env import default_dtype
+from .env import default_dtype, default_device
 import torch
+import torch.utils.data as data
+import pickle


 class ANIDataset(Dataset):

     def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
-                 transform=(), dtype=default_dtype):
+                 transform=(), dtype=default_dtype, device=default_device):
         super(ANIDataset, self).__init__()
         self.path = path
         self.chunks_size = chunk_size
         self.shuffle = shuffle
         self.properties = properties
         self.dtype = dtype
+        self.device = device

         # get name of files storing data
         files = []
         if isdir(path):
-            for f in listdir(path):
+            for f in os.listdir(path):
                 f = join(path, f)
                 if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                     files.append(f)
@@ -35,16 +38,17 @@ class ANIDataset(Dataset):
             for m in anidataloader(f):
                 full = {
                     'coordinates': torch.from_numpy(m['coordinates'])
-                    .type(dtype)
+                    .type(dtype).to(device)
                 }
                 conformations = full['coordinates'].shape[0]
                 for i in properties:
-                    full[i] = torch.from_numpy(m[i]).type(dtype)
+                    full[i] = torch.from_numpy(m[i]).type(dtype).to(device)
                 species = m['species']
                 if shuffle:
-                    indices = torch.randperm(conformations)
+                    indices = torch.randperm(conformations, device=device)
                 else:
-                    indices = torch.arange(conformations, dtype=torch.int64)
+                    indices = torch.arange(conformations, dtype=torch.int64,
+                                            device=device)
                 num_chunks = (conformations + chunk_size - 1) // chunk_size
                 for i in range(num_chunks):
                     chunk_start = i * chunk_size
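
The changes in this hunk all follow one pattern: build the tensor from NumPy, convert it to the configured dtype, then move it to the configured device, so a GPU run needs no further copies. A minimal sketch of that pattern, assuming only NumPy and PyTorch (the local default_device below merely mirrors the role of the one imported from .env):

import numpy as np
import torch

# Mirrors the role of .env's default_device: prefer the GPU when present.
default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

coordinates = np.random.rand(10, 3)        # NumPy gives float64 by default
tensor = (torch.from_numpy(coordinates)
          .type(torch.float32)             # convert dtype first
          .to(default_device))             # then move to the target device
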
@@ -66,6 +70,25 @@ class ANIDataset(Dataset):
         return len(self.chunks)
+
+
+def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
+    """Generate an 80-10-10 split of the dataset, and checkpoint
+    the resulting subsets."""
+    if not os.path.isfile(checkpoint):
+        full_dataset = ANIDataset(dataset_path, chunk_size, *args, **kwargs)
+        training_size = int(len(full_dataset) * 0.8)
+        validation_size = int(len(full_dataset) * 0.1)
+        testing_size = len(full_dataset) - training_size - validation_size
+        lengths = [training_size, validation_size, testing_size]
+        subsets = data.random_split(full_dataset, lengths)
+        with open(checkpoint, 'wb') as f:
+            pickle.dump(subsets, f)
+    # load the dataset splits from the checkpoint file
+    with open(checkpoint, 'rb') as f:
+        training, validation, testing = pickle.load(f)
+    return training, validation, testing
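
The new load_or_create helper computes the 80-10-10 split once, pickles the resulting Subset objects, and on later runs simply reloads them, so every run trains and tests on the same split. A usage sketch; the checkpoint path, dataset directory, and chunk size below are all hypothetical placeholders:

# 'split.pkl' and 'data/ani' are placeholder paths; 256 is an arbitrary
# chunk size. The first call creates and pickles the split; later calls
# just reload it from the checkpoint file.
training, validation, testing = load_or_create('split.pkl', 'data/ani', 256)
print(len(training), len(validation), len(testing))
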
+
+
 def _collate(batch):
     input_keys = ['coordinates', 'species']
     inputs = [{k: i[k] for k in input_keys} for i in batch]
...
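
The truncated _collate above gathers the input keys ('coordinates', 'species') from each item of a batch; the rest of its body is elided in this view. A sketch of how such a function plugs into a DataLoader, assuming _collate returns an (inputs, targets) pair (the elided part does not confirm this) and using an arbitrary batch size:

from torch.utils.data import DataLoader

# collate_fn receives the list of dataset items in a batch and returns
# whatever structure the training loop expects.
loader = DataLoader(training, batch_size=8, shuffle=True, collate_fn=_collate)
for inputs, targets in loader:   # the shape of this pair is an assumption
    pass                         # feed inputs/targets to the training step
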
@@ -30,7 +30,7 @@ class DictMetric(Metric):
         self.metric.update((y_pred[self.key], y[self.key]))

     def compute(self):
-        self.metric.compute()
+        return self.metric.compute()


 energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
...
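
The one-line fix to compute() matters because ignite reads a metric's value through that method; without the return, every evaluation would report None. A sketch of the intended use, assuming DictMetric's constructor mirrors DictLoss(key, metric) and using ignite's create_supervised_evaluator; model and val_loader are hypothetical names:

from ignite.engine import create_supervised_evaluator
from ignite.metrics import MeanSquaredError

# Wrap ignite's MSE so it compares the 'energies' entries of dict-valued
# predictions and targets, as DictMetric is written to do.
metric = DictMetric('energies', MeanSquaredError())
evaluator = create_supervised_evaluator(model, metrics={'energy_mse': metric})
state = evaluator.run(val_loader)
print(state.metrics['energy_mse'])   # was always None before the fix
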
@@ -31,7 +31,8 @@ class CustomModel(ANIModel):
                 raise ValueError(
                     '''output length of each atomic neural network must
                     match''')
-            setattr(self, 'model_' + i, model_X)
         super(CustomModel, self).__init__(aev_computer, suffixes, reducer,
                                           output_length, models, derivative,
                                           derivative_graph, benchmark)
+        for i in per_species:
+            setattr(self, 'model_' + i, per_species[i])
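
Moving the setattr registration below the super().__init__() call is not cosmetic: torch.nn.Module.__setattr__ can only record a child module after nn.Module.__init__ has created the internal _modules registry, and assigning a module earlier raises "cannot assign module before Module.__init__() call". A minimal sketch of the pattern, with a hypothetical PerSpeciesModel standing in for CustomModel:

import torch.nn as nn

class PerSpeciesModel(nn.Module):
    """Hypothetical stand-in: registers one network per chemical species."""

    def __init__(self, per_species):
        super(PerSpeciesModel, self).__init__()   # must run before setattr
        for name, net in per_species.items():
            # Registration makes each net visible to .parameters() and .to().
            setattr(self, 'model_' + name, net)

model = PerSpeciesModel({'H': nn.Linear(8, 1), 'C': nn.Linear(8, 1)})
print(sum(p.numel() for p in model.parameters()))   # counts both subnetworks
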