".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e3efbc2d9094685dd2d4ae143853941f82f167af"
Unverified Commit 18e4867d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add ignite API helpers, also temporarily disable JIT and python2 compatibiltiy (#29)

parent 965efee2
...@@ -12,4 +12,4 @@ steps: ...@@ -12,4 +12,4 @@ steps:
commands: commands:
- flake8 - flake8
- python setup.py test - python setup.py test
- python2 setup.py test # - python2 setup.py test
\ No newline at end of file \ No newline at end of file
...@@ -4,7 +4,7 @@ cmdclass = {'build_sphinx': BuildDoc} ...@@ -4,7 +4,7 @@ cmdclass = {'build_sphinx': BuildDoc}
setup(name='torchani', setup(name='torchani',
version='0.1', version='0.1',
description='ANI based on pytorch', description='PyTorch implementation of ANI',
url='https://github.com/zasdfgbnm/torchani', url='https://github.com/zasdfgbnm/torchani',
author='Xiang Gao', author='Xiang Gao',
author_email='qasdfgtyuiop@ufl.edu', author_email='qasdfgtyuiop@ufl.edu',
...@@ -13,6 +13,7 @@ setup(name='torchani', ...@@ -13,6 +13,7 @@ setup(name='torchani',
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'torch', 'torch',
'pytorch-ignite',
'lark-parser', 'lark-parser',
'h5py', 'h5py',
], ],
......
import sys
if sys.version_info.major >= 3:
import os
import unittest
import torch
from ignite.engine import create_supervised_trainer
import torchani
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, 'dataset/ani_gdb_s01.h5')
chunksize = 32
batch_chunks = 32
dtype = torch.float32
device = torch.device('cpu')
class TestIgnite(unittest.TestCase):
def testIgnite(self):
ds = torchani.data.ANIDataset(path, chunksize)
loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
loss = torchani.ignite.DictLosses({'energies': torch.nn.MSELoss()})
optimizer = torch.optim.SGD(container.parameters(),
lr=0.001, momentum=0.8)
trainer = create_supervised_trainer(container, optimizer, loss)
trainer.run(loader, max_epochs=10)
if __name__ == '__main__':
unittest.main()
from .energyshifter import EnergyShifter from .energyshifter import EnergyShifter
from . import models from . import models
from . import data
from . import ignite
from .aev import SortedAEV from .aev import SortedAEV
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble, default_dtype, default_device buildin_model_prefix, buildin_ensemble, default_dtype, default_device
__all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data', __all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir', 'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble', 'buildin_model_prefix', 'buildin_ensemble',
'default_dtype', 'default_device'] 'default_dtype', 'default_device']
...@@ -3,7 +3,7 @@ import timeit ...@@ -3,7 +3,7 @@ import timeit
import functools import functools
class BenchmarkedModule(torch.jit.ScriptModule): class BenchmarkedModule(torch.nn.Module):
"""Module with member function benchmarking support. """Module with member function benchmarking support.
The benchmarking is done by wrapping the original member function with The benchmarking is done by wrapping the original member function with
......
...@@ -2,18 +2,20 @@ from torch.utils.data import Dataset, DataLoader ...@@ -2,18 +2,20 @@ from torch.utils.data import Dataset, DataLoader
from os.path import join, isfile, isdir from os.path import join, isfile, isdir
from os import listdir from os import listdir
from .pyanitools import anidataloader from .pyanitools import anidataloader
from .env import default_dtype
import torch import torch
class ANIDataset(Dataset): class ANIDataset(Dataset):
def __init__(self, path, chunk_size, shuffle=True, def __init__(self, path, chunk_size, shuffle=True,
properties=['energies']): properties=['energies'], dtype=default_dtype):
super(ANIDataset, self).__init__() super(ANIDataset, self).__init__()
self.path = path self.path = path
self.chunks_size = chunk_size self.chunks_size = chunk_size
self.shuffle = shuffle self.shuffle = shuffle
self.properties = properties self.properties = properties
self.dtype = dtype
# get name of files storing data # get name of files storing data
files = [] files = []
...@@ -33,10 +35,11 @@ class ANIDataset(Dataset): ...@@ -33,10 +35,11 @@ class ANIDataset(Dataset):
for m in anidataloader(f): for m in anidataloader(f):
full = { full = {
'coordinates': torch.from_numpy(m['coordinates']) 'coordinates': torch.from_numpy(m['coordinates'])
.type(dtype)
} }
conformations = full['coordinates'].shape[0] conformations = full['coordinates'].shape[0]
for i in properties: for i in properties:
full[i] = torch.from_numpy(m[i]) full[i] = torch.from_numpy(m[i]).type(dtype)
species = m['species'] species = m['species']
if shuffle: if shuffle:
indices = torch.randperm(conformations) indices = torch.randperm(conformations)
......
from .container import Container
from .dict_loss import DictLosses
__all__ = ['Container', 'DictLosses']
import torch
from ..models import BatchModel
class Container(torch.nn.Module):
def __init__(self, models):
super(Container, self).__init__()
self.keys = models.keys()
for i in models:
if not isinstance(models[i], BatchModel):
raise ValueError('Container must contain batch models')
setattr(self, 'model_' + i, models[i])
def forward(self, batch):
output = {}
for i in self.keys:
model = getattr(self, 'model_' + i)
output[i] = model(batch)
return output
from torch.nn.modules.loss import _Loss
class DictLosses(_Loss):
def __init__(self, losses):
super(DictLosses, self).__init__()
self.losses = losses
def forward(self, input, other):
total = 0
for i in self.losses:
total += self.losses[i](input[i], other[i])
return total
...@@ -7,7 +7,7 @@ import math ...@@ -7,7 +7,7 @@ import math
import struct import struct
class NeuroChemAtomicNetwork(torch.jit.ScriptModule): class NeuroChemAtomicNetwork(torch.nn.Module):
"""Per atom aev->y transformation, loaded from NeuroChem network dir. """Per atom aev->y transformation, loaded from NeuroChem network dir.
Attributes Attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment