Unverified Commit 18e4867d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add ignite API helpers, also temporarily disable JIT and python2 compatibility (#29)

parent 965efee2
......@@ -12,4 +12,4 @@ steps:
commands:
- flake8
- python setup.py test
- python2 setup.py test
\ No newline at end of file
# - python2 setup.py test
\ No newline at end of file
......@@ -4,7 +4,7 @@ cmdclass = {'build_sphinx': BuildDoc}
setup(name='torchani',
version='0.1',
description='ANI based on pytorch',
description='PyTorch implementation of ANI',
url='https://github.com/zasdfgbnm/torchani',
author='Xiang Gao',
author_email='qasdfgtyuiop@ufl.edu',
......@@ -13,6 +13,7 @@ setup(name='torchani',
include_package_data=True,
install_requires=[
'torch',
'pytorch-ignite',
'lark-parser',
'h5py',
],
......
import sys
if sys.version_info.major >= 3:
import os
import unittest
import torch
from ignite.engine import create_supervised_trainer
import torchani
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, 'dataset/ani_gdb_s01.h5')
chunksize = 32
batch_chunks = 32
dtype = torch.float32
device = torch.device('cpu')
class TestIgnite(unittest.TestCase):
    """Smoke test: train an ANI model for a few epochs through pytorch-ignite.

    Only checks that the ignite training loop runs end-to-end without
    raising; no accuracy or loss-value assertion is made.
    """

    def testIgnite(self):
        # Build a batched dataloader over the small bundled ANI dataset.
        ds = torchani.data.ANIDataset(path, chunksize)
        loader = torchani.data.dataloader(ds, batch_chunks)
        # AEV featurizer + NeuroChem neural network potential.
        aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
        nnp = torchani.models.NeuroChemNNP(aev_computer)

        class Flatten(torch.nn.Module):
            # Flattens the wrapped model's output to 1-D — presumably so
            # predictions match the shape of the target energies; confirm
            # against torchani.models.BatchModel.
            def __init__(self, model):
                super(Flatten, self).__init__()
                self.model = model

            def forward(self, *input):
                return self.model(*input).flatten()

        nnp = Flatten(nnp)
        batch_nnp = torchani.models.BatchModel(nnp)
        # Container/DictLosses adapt dict-valued batches to ignite's
        # (model, optimizer, loss) trainer interface.
        container = torchani.ignite.Container({'energies': batch_nnp})
        loss = torchani.ignite.DictLosses({'energies': torch.nn.MSELoss()})
        optimizer = torch.optim.SGD(container.parameters(),
                                    lr=0.001, momentum=0.8)
        trainer = create_supervised_trainer(container, optimizer, loss)
        trainer.run(loader, max_epochs=10)
if __name__ == '__main__':
unittest.main()
from .energyshifter import EnergyShifter
from . import models
from . import data
from . import ignite
from .aev import SortedAEV
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble, default_dtype, default_device
__all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data',
__all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble',
'default_dtype', 'default_device']
......@@ -3,7 +3,7 @@ import timeit
import functools
class BenchmarkedModule(torch.jit.ScriptModule):
class BenchmarkedModule(torch.nn.Module):
"""Module with member function benchmarking support.
The benchmarking is done by wrapping the original member function with
......
......@@ -2,18 +2,20 @@ from torch.utils.data import Dataset, DataLoader
from os.path import join, isfile, isdir
from os import listdir
from .pyanitools import anidataloader
from .env import default_dtype
import torch
class ANIDataset(Dataset):
def __init__(self, path, chunk_size, shuffle=True,
properties=['energies']):
properties=['energies'], dtype=default_dtype):
super(ANIDataset, self).__init__()
self.path = path
self.chunks_size = chunk_size
self.shuffle = shuffle
self.properties = properties
self.dtype = dtype
# get name of files storing data
files = []
......@@ -33,10 +35,11 @@ class ANIDataset(Dataset):
for m in anidataloader(f):
full = {
'coordinates': torch.from_numpy(m['coordinates'])
.type(dtype)
}
conformations = full['coordinates'].shape[0]
for i in properties:
full[i] = torch.from_numpy(m[i])
full[i] = torch.from_numpy(m[i]).type(dtype)
species = m['species']
if shuffle:
indices = torch.randperm(conformations)
......
from .container import Container
from .dict_loss import DictLosses
__all__ = ['Container', 'DictLosses']
import torch
from ..models import BatchModel
class Container(torch.nn.Module):
    """Bundle a dict of :class:`BatchModel` instances into one module.

    Each key of ``models`` names an output; the forward pass returns a
    dict mapping every key to that model's output on the shared batch.
    """

    def __init__(self, models):
        super(Container, self).__init__()
        # Remember the output keys so forward() can look the models back up.
        self.keys = models.keys()
        for name in models:
            model = models[name]
            # Only batch-aware models are accepted.
            if not isinstance(model, BatchModel):
                raise ValueError('Container must contain batch models')
            # Register as an attribute so torch.nn.Module tracks parameters.
            setattr(self, 'model_' + name, model)

    def forward(self, batch):
        # Apply every registered model to the same batch.
        return {name: getattr(self, 'model_' + name)(batch)
                for name in self.keys}
from torch.nn.modules.loss import _Loss
class DictLosses(_Loss):
    """Sum of per-key losses over dictionary-valued predictions and targets.

    ``losses`` maps each output key to a loss module; the forward pass
    evaluates every loss on the matching entries of the two input dicts
    and returns their sum.
    """

    def __init__(self, losses):
        super(DictLosses, self).__init__()
        # mapping: output key -> loss module applied to that key
        self.losses = losses

    def forward(self, input, other):
        # sum() starts from the int 0, matching the original accumulator,
        # so an empty mapping yields 0 rather than raising.
        return sum(self.losses[key](input[key], other[key])
                   for key in self.losses)
......@@ -7,7 +7,7 @@ import math
import struct
class NeuroChemAtomicNetwork(torch.jit.ScriptModule):
class NeuroChemAtomicNetwork(torch.nn.Module):
"""Per atom aev->y transformation, loaded from NeuroChem network dir.
Attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment