Unverified commit 87c4184a, authored by Gao, Xiang and committed by GitHub
Browse files

add style check to unit test (#12)

parent a61a1b3e
FROM zasdfgbnm/pytorch-master FROM zasdfgbnm/pytorch-master
RUN pacman -Sy --noconfirm python-sphinx python2-sphinx RUN pacman -Sy --noconfirm python-sphinx python2-sphinx flake8
COPY . /torchani COPY . /torchani
RUN cd torchani && pip install . RUN cd torchani && pip install .
RUN cd torchani && pip2 install . RUN cd torchani && pip2 install .
from benchmark import Benchmark
import torchani
class ANIBenchmark(Benchmark):
    """Benchmark driver backed by torchani's ``SortedAEV`` and ``ModelOnAEV``.

    The model is built with ``benchmark=True`` and ``derivative=True`` and its
    ``timers`` mapping is read after each run to produce the report.
    """

    def __init__(self, device):
        """Create the AEV computer and the model on ``device``."""
        super(ANIBenchmark, self).__init__(device)
        aev = torchani.SortedAEV(device=device)
        self.aev_computer = aev
        self.model = torchani.ModelOnAEV(
            aev, benchmark=True, derivative=True, from_nc=None)

    def _collect_timers(self):
        """Copy the model timers into a report dict, then reset them.

        The model's ``'nn'`` timer is reported as ``'energy'`` and its
        ``'derivative'`` timer as ``'force'``.
        """
        timers = self.model.timers
        report = {
            'aev': timers['aev'],
            'energy': timers['nn'],
            'force': timers['derivative']
        }
        # Reset only after the values were read, so the next run starts clean.
        self.model.reset_timers()
        return report

    def oneByOne(self, coordinates, species):
        """Run the model on every conformation separately and report timers.

        ``coordinates`` is sliced along its first axis into chunks of length
        one; each chunk is passed to the model together with ``species``.

        Returns
        -------
        dict
            Timer values under the keys ``'aev'``, ``'energy'``, ``'force'``.
        """
        n_conformations = coordinates.shape[0]
        on_device = coordinates.to(self.device)
        for index in range(n_conformations):
            # Slice (rather than index) so the leading axis is kept.
            single = on_device[index:index+1, :, :]
            self.model(single, species)
        return self._collect_timers()

    def inBatch(self, coordinates, species):
        """Run the model once on the whole ``coordinates`` and report timers.

        Returns the same kind of dict as ``oneByOne``.
        """
        on_device = coordinates.to(self.device)
        self.model(on_device, species)
        return self._collect_timers()
import numpy import torchani
class Benchmark: class ANIBenchmark:
"""Abstract class for benchmarking ANI implementations"""
def __init__(self, device): def __init__(self, device):
self.device = device super(ANIBenchmark, self).__init__(device)
self.aev_computer = torchani.SortedAEV(device=device)
self.model = torchani.ModelOnAEV(
self.aev_computer, benchmark=True, derivative=True, from_nc=None)
def oneByOne(self, coordinates, species): def oneByOne(self, coordinates, species):
"""Benchmarking the given dataset of computing energies and forces one at a time conformations = coordinates.shape[0]
coordinates = coordinates.to(self.device)
Parameters for i in range(conformations):
---------- c = coordinates[i:i+1, :, :]
coordinates : numpy.ndarray self.model(c, species)
Array of shape (conformations, atoms, 3) ret = {
species : list 'aev': self.model.timers['aev'],
List of species for this molecule. The length of the list must be the same as 'energy': self.model.timers['nn'],
atoms in the molecule. 'force': self.model.timers['derivative']
}
Returns self.model.reset_timers()
------- return ret
dict
Dictionary storing the times for computing AEVs, energies and forces, in seconds.
The dictionary should contain the following keys:
aev : the time used to compute AEVs from coordinates with given neighbor list.
energy : the time used to compute energies, when the AEVs are given.
force : the time used to compute forces, when the energies and AEVs are given.
"""
# return { 'neighborlist': 0, 'aev': 0, 'energy': 0, 'force': 0 }
raise NotImplementedError('subclass must implement this method')
def inBatch(self, coordinates, species): def inBatch(self, coordinates, species):
"""Benchmarking the given dataset of computing energies and forces in batch mode coordinates = coordinates.to(self.device)
self.model(coordinates, species)
The signature of this function is the same as `oneByOne`""" ret = {
# return { 'neighborlist': 0, 'aev': 0, 'energy': 0, 'force': 0 } 'aev': self.model.timers['aev'],
raise NotImplementedError('subclass must implement this method') 'energy': self.model.timers['nn'],
'force': self.model.timers['derivative']
}
self.model.reset_timers()
return ret
...@@ -2,7 +2,6 @@ from ase import Atoms ...@@ -2,7 +2,6 @@ from ase import Atoms
from ase.calculators.tip3p import TIP3P, rOH, angleHOH from ase.calculators.tip3p import TIP3P, rOH, angleHOH
from ase.md import Langevin from ase.md import Langevin
import ase.units as units import ase.units as units
from ase.io.trajectory import Trajectory
import numpy import numpy
import h5py import h5py
from rdkit import Chem from rdkit import Chem
...@@ -10,14 +9,12 @@ from rdkit.Chem import AllChem ...@@ -10,14 +9,12 @@ from rdkit.Chem import AllChem
# from asap3 import EMT # from asap3 import EMT
from ase.calculators.emt import EMT from ase.calculators.emt import EMT
from multiprocessing import Pool from multiprocessing import Pool
from tqdm import tqdm, tqdm_notebook, trange from tqdm import tqdm, trange
tqdm.monitor_interval = 0
from selected_system import mols, mol_file from selected_system import mols, mol_file
import functools
conformations = 1024 conformations = 1024
T = 30 T = 30
tqdm.monitor_interval = 0
fw = h5py.File("waters.hdf5", "w") fw = h5py.File("waters.hdf5", "w")
fm = h5py.File(mol_file, "w") fm = h5py.File(mol_file, "w")
...@@ -96,7 +93,7 @@ if __name__ == '__main__': ...@@ -96,7 +93,7 @@ if __name__ == '__main__':
print('done with molecules') print('done with molecules')
with Pool() as p: with Pool() as p:
p.starmap(waterbox, [(10, 10, 10, 0), (20, 20, 10, p.starmap(waterbox, [(10, 10, 10, 0), (20, 20, 10, 1),
1), (30, 30, 30, 2), (40, 40, 40, 3)]) (30, 30, 30, 2), (40, 40, 40, 3)])
print(list(fw.keys())) print(list(fw.keys()))
print('done with water boxes') print('done with water boxes')
# Benchmark molecules as SMILES strings, grouped under string labels.
# The consumer script in this listing prints each label as
# 'number of atoms:' — presumably the (approximate) atom count of the
# molecules in that group; '4,5,6' appears to cover several small sizes.
# TODO confirm against how the labels were chosen.
mols = {
'20': [
'COC(=O)c1ccc([N+](=O)[O-])cc1',
'O=c1nnc2ccccc2n1CO',
'CCc1ccc([N+](=O)[O-])cc1',
'Nc1ccc(c2cnco2)cc1',
'COc1ccc(N)c(N)c1',
'O=C(O)CNc1ccccc1',
'NC(=O)NNc1ccccc1',
'Cn1c(=O)oc(=O)c2ccccc12',
'CC(=O)Nc1ccc(O)cc1',
'COc1ccc(CC#N)cc1'
],
'50': [
'O=[N+]([O-])c1ccc(NN=Cc2ccc(C=NNc3ccc([N+](=O)[O-])cc3[N+](=O)[O-])cc2)c([N+](=O)[O-])c1',
'CCCCCc1nccnc1OCC(C)(C)CC(C)C',
'CC(C)(C)c1ccc(N(C(=O)c2ccccc2)C(C)(C)C)cc1',
'CCCCCCCCCCCOC(=O)Nc1ccccc1',
'CC(=O)NCC(CN1CCCC1)(c1ccccc1)c1ccccc1',
'CCCCCc1cnc(C)c(OCC(C)(C)CCC)n1',
'CCCCCCCCCCCCN1CCOC(=O)C1',
'CCCCOc1ccc(C=Nc2ccc(CCCC)cc2)cc1',
'CC1CC(C)C(=NNC(=O)N)C(C(O)CC2CC(=O)NC(=O)C2)C1',
'CCCCCOc1ccc(C=Nc2ccc(C(=O)OCC)cc2)cc1'
],
'10': [
'N#CCC(=O)N',
'N#CCCO',
'O=C1NC(=O)C(=O)N1',
'COCC#N',
'N#CCNC=O',
'ON=CC=NO',
'NCC(=O)O',
'NC(=O)CO',
'N#Cc1ccco1',
'C=CC(=O)N'
],
'4,5,6': [
'C',
'C#CC#N',
'C=C',
'CC#N',
'C#CC#C',
'O=CC#C',
'C#C'
],
'100': [
'CC(C)C[C@@H](C(=O)O)NC(=O)C[C@@H]([C@H](CC1CCCCC1)NC(=O)CC[C@@H]([C@H](Cc2ccccc2)NC(=O)OC(C)(C)C)O)O',
'CC(C)(C)OC(=O)N[C@@H](Cc1ccccc1)[C@@H](CN[C@@H](Cc2ccccc2)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](Cc3ccccc3)C(=O)N)O',
'CC(C)(C)OC(=O)N[C@@H](Cc1ccccc1)[C@H](CN[C@@H](Cc2ccccc2)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](Cc3ccccc3)C(=O)N)O',
'CC[C@H](c1ccc(cc1)O)[C@H](c2ccc(cc2)O)C(=O)OCCCCCCCCOC(=O)C(c3ccc(cc3)O)C(CC)c4ccc(cc4)O',
'CC/C(=C\\CC[C@H](C)C[C@@H](C)CC[C@@H]([C@H](C)C(=O)C[C@H]([C@H](C)[C@@H](C)OC(=O)C[C@H](/C(=C(\\C)/C(=O)O)/C(=O)O)O)O)O)/C=C/C(=O)O',
'CC[C@H](C)[C@H]1C(=O)NCCCOc2ccc(cc2)C[C@@H](C(=O)N1)NC(=O)[C@@H]3Cc4ccc(cc4)OCCCCC(=O)N[C@H](C(=O)N3)C(C)C',
'CC(C)(C)CC(C)(C)c1ccc(cc1)OCCOCCOCCOCCOCCOCCOCCOCCOCCO',
'CCOC(=O)CC[C@H](C[C@@H]1CCNC1=O)NC(=O)[C@H](Cc2ccccc2)NC(=O)[C@H](CCC(=O)OC(C)(C)C)NC(=O)OCc3ccccc3',
'C[C@]12CC[C@@H]3c4ccc(cc4CC[C@H]3[C@@H]1C[C@@H]([C@@H]2O)CCCCCCCCC(=O)OC[C@@H]5[C@H]([C@H]([C@@H](O5)n6cnc7c6ncnc7N)O)O)O',
'c1cc(ccc1CCc2c[nH]c3c2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)N[C@@H](CCC(=O)N[C@@H](CCC(=O)N[C@H](CCC(=O)O)C(=O)O)C(=O)O)C(=O)O)C(=O)O'
],
'305': [
'[H]/N=C(/N)\\NCCC[C@H](C(=O)N[C@H]([C@@H](C)O)C(=O)N[C@H](Cc1ccc(cc1)O)C(=O)NCCCC[C@@H](C(=O)NCCCC[C@@H](C(=O)NCC(=O)O)NC(=O)[C@H](CCCCNC(=O)[C@@H](Cc2ccc(cc2)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc3ccc(cc3)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc4ccc(cc4)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)N'
]
}
# Path of the HDF5 file holding data for the molecules above; other scripts
# in this listing open it with h5py (once in "w" mode, once in "r" mode).
mol_file = "molecules.hdf5"
from selected_system import mols, mol_file
import h5py
import os
# Export conformations stored in the HDF5 molecule file as XYZ text files
# under 'benchmark_xyz/'.
# NOTE(review): this listing has lost its indentation, so the nesting of the
# loops below, and which loop the trailing `break` exits, cannot be confirmed
# from here. Check against the original script before restructuring.
# os.path.join with a single argument returns that argument unchanged.
fm = h5py.File(os.path.join(mol_file), "r")
# `i` iterates over the keys of `mols` (the group labels).
for i in mols:
print('number of atoms:', i)
smiles = mols[i]
for s in smiles:
# Dataset key is the SMILES with '/' replaced by '_' — presumably because
# '/' is the HDF5 path separator; verify against the script that writes fm.
key = s.replace('/', '_')
# The output file is named after the group label, not the molecule, and is
# opened in 'w' mode, so each open truncates any earlier content.
filename = i
with open('benchmark_xyz/' + filename + '.xyz', 'w') as fxyz:
# `[()]` reads the dataset's full contents.
# assumes shape (conformations, atoms, 3), as indexed below — TODO confirm
coordinates = fm[key][()]
# 'species' attribute is a whitespace-separated string of per-atom labels.
species = fm[key].attrs['species'].split()
conformations = coordinates.shape[0]
atoms = len(species)
# NOTE(review): this loop reuses the name `i`, which is also the loop
# variable of the loop over `mols` above.
for i in range(conformations):
# Frame header: atom count line, then a free-form comment line.
fxyz.write('{}\n{}\n'.format(
atoms, 'smiles:{}\tconformation:{}'.format(s, i)))
# One line per atom: species label followed by x, y, z.
for j in range(atoms):
ss = species[j]
xyz = coordinates[i, j, :]
x = xyz[0]
y = xyz[1]
z = xyz[2]
fxyz.write('{} {} {} {}\n'.format(ss, x, y, z))
# NOTE(review): the enclosing loop of this break is unknown (see top note).
break
...@@ -10,4 +10,5 @@ steps: ...@@ -10,4 +10,5 @@ steps:
unit-tests: unit-tests:
image: '${{build-torchani}}' image: '${{build-torchani}}'
commands: commands:
- flake8
- python setup.py test - python setup.py test
\ No newline at end of file
...@@ -3,14 +3,11 @@ import torchani ...@@ -3,14 +3,11 @@ import torchani
import torchani.data import torchani.data
import math import math
import timeit import timeit
import itertools
import os
import sys import sys
import pickle import pickle
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from common import * from common import get_or_create_model, Averager, evaluate
import sys
import json import json
chunk_size = 256 chunk_size = 256
...@@ -27,11 +24,14 @@ with open('data/dataset.dat', 'rb') as f: ...@@ -27,11 +24,14 @@ with open('data/dataset.dat', 'rb') as f:
testing, chunk_size, batch_chunks) testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader( training_dataloader = torch.utils.data.DataLoader(
training, batch_sampler=training_sampler, collate_fn=torchani.data.collate) training, batch_sampler=training_sampler,
collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader( validation_dataloader = torch.utils.data.DataLoader(
validation, batch_sampler=validation_sampler, collate_fn=torchani.data.collate) validation, batch_sampler=validation_sampler,
collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader( testing_dataloader = torch.utils.data.DataLoader(
testing, batch_sampler=testing_sampler, collate_fn=torchani.data.collate) testing, batch_sampler=testing_sampler,
collate_fn=torchani.data.collate)
writer = SummaryWriter('runs/adam-{}'.format(sys.argv[1])) writer = SummaryWriter('runs/adam-{}'.format(sys.argv[1]))
...@@ -48,8 +48,8 @@ def subset_rmse(subset_dataloader): ...@@ -48,8 +48,8 @@ def subset_rmse(subset_dataloader):
for molecule_id in batch: for molecule_id in batch:
_species = subset_dataloader.dataset.species[molecule_id] _species = subset_dataloader.dataset.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
count, squared_error = evaluate( count, squared_error = evaluate(
model, coordinates, energies, _species) model, coordinates, energies, _species)
squared_error = squared_error.item() squared_error = squared_error.item()
...@@ -73,13 +73,15 @@ best_validation_rmse = math.inf ...@@ -73,13 +73,15 @@ best_validation_rmse = math.inf
best_epoch = 0 best_epoch = 0
start = timeit.default_timer() start = timeit.default_timer()
while True: while True:
for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch), total=len(training_sampler)): for batch in tqdm(training_dataloader,
desc='epoch {}'.format(epoch),
total=len(training_sampler)):
a = Averager() a = Averager()
for molecule_id in batch: for molecule_id in batch:
_species = training.species[molecule_id] _species = training.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
count, squared_error = evaluate( count, squared_error = evaluate(
model, coordinates, energies, _species) model, coordinates, energies, _species)
a.add(count, squared_error / len(_species)) a.add(count, squared_error / len(_species))
......
...@@ -3,14 +3,10 @@ import torchani ...@@ -3,14 +3,10 @@ import torchani
import torchani.data import torchani.data
import math import math
import timeit import timeit
import itertools
import os
import sys
import pickle import pickle
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from common import * from common import get_or_create_model, Averager, evaluate
from copy import deepcopy
chunk_size = 256 chunk_size = 256
batch_chunks = 1024 // chunk_size batch_chunks = 1024 // chunk_size
...@@ -26,11 +22,14 @@ with open('data/dataset.dat', 'rb') as f: ...@@ -26,11 +22,14 @@ with open('data/dataset.dat', 'rb') as f:
testing, chunk_size, batch_chunks) testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader( training_dataloader = torch.utils.data.DataLoader(
training, batch_sampler=training_sampler, collate_fn=torchani.data.collate) training, batch_sampler=training_sampler,
collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader( validation_dataloader = torch.utils.data.DataLoader(
validation, batch_sampler=validation_sampler, collate_fn=torchani.data.collate) validation, batch_sampler=validation_sampler,
collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader( testing_dataloader = torch.utils.data.DataLoader(
testing, batch_sampler=testing_sampler, collate_fn=torchani.data.collate) testing, batch_sampler=testing_sampler,
collate_fn=torchani.data.collate)
writer = SummaryWriter() writer = SummaryWriter()
...@@ -47,8 +46,8 @@ def subset_rmse(subset_dataloader): ...@@ -47,8 +46,8 @@ def subset_rmse(subset_dataloader):
for molecule_id in batch: for molecule_id in batch:
_species = subset_dataloader.dataset.species[molecule_id] _species = subset_dataloader.dataset.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
count, squared_error = evaluate(coordinates, energies, _species) count, squared_error = evaluate(coordinates, energies, _species)
squared_error = squared_error.item() squared_error = squared_error.item()
a.add(count, squared_error) a.add(count, squared_error)
...@@ -71,13 +70,14 @@ best_validation_rmse = math.inf ...@@ -71,13 +70,14 @@ best_validation_rmse = math.inf
best_epoch = 0 best_epoch = 0
start = timeit.default_timer() start = timeit.default_timer()
while True: while True:
for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch), total=len(training_sampler)): for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch),
total=len(training_sampler)):
a = Averager() a = Averager()
for molecule_id in batch: for molecule_id in batch:
_species = training.species[molecule_id] _species = training.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
count, squared_error = evaluate( count, squared_error = evaluate(
model, coordinates, energies, _species) model, coordinates, energies, _species)
a.add(count, squared_error / len(_species)) a.add(count, squared_error / len(_species))
......
...@@ -23,7 +23,8 @@ hyperparams = [ # (chunk size, batch chunks) ...@@ -23,7 +23,8 @@ hyperparams = [ # (chunk size, batch chunks)
] ]
for chunk_size, batch_chunks in hyperparams: for chunk_size, batch_chunks in hyperparams:
with open('data/avg-{}-{}.dat'.format(chunk_size, batch_chunks), 'rb') as f: with open('data/avg-{}-{}.dat'.format(chunk_size, batch_chunks),
'rb') as f:
ag, agsqr = pickle.load(f) ag, agsqr = pickle.load(f)
variance = torch.sum(agsqr) - torch.sum(ag**2) variance = torch.sum(agsqr) - torch.sum(ag**2)
stddev = torch.sqrt(variance).item() stddev = torch.sqrt(variance).item()
......
...@@ -3,18 +3,16 @@ import torch ...@@ -3,18 +3,16 @@ import torch
import torchani import torchani
import configs import configs
import torchani.data import torchani.data
import math
from tqdm import tqdm from tqdm import tqdm
import itertools
import os
import pickle import pickle
from common import get_or_create_model, Averager, evaluate
device = configs.device
if len(sys.argv) >= 2: if len(sys.argv) >= 2:
configs.device = torch.device(sys.argv[1]) device = torch.device(sys.argv[1])
from common import *
ds = torchani.data.load_dataset(configs.data_path) ds = torchani.data.load_dataset(configs.data_path)
model = get_or_create_model('/tmp/model.pt', device=device)
# just to conveniently zero grads # just to conveniently zero grads
optimizer = torch.optim.Adam(model.parameters()) optimizer = torch.optim.Adam(model.parameters())
...@@ -31,8 +29,8 @@ def batch_gradient(batch): ...@@ -31,8 +29,8 @@ def batch_gradient(batch):
for molecule_id in batch: for molecule_id in batch:
_species = ds.species[molecule_id] _species = ds.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
a.add(*evaluate(coordinates, energies, _species)) a.add(*evaluate(coordinates, energies, _species))
mse = a.avg() mse = a.avg()
optimizer.zero_grad() optimizer.zero_grad()
...@@ -59,7 +57,8 @@ def compute(chunk_size, batch_chunks): ...@@ -59,7 +57,8 @@ def compute(chunk_size, batch_chunks):
agsqr.add(1, g**2) agsqr.add(1, g**2)
ag = ag.avg() ag = ag.avg()
agsqr = agsqr.avg() agsqr = agsqr.avg()
with open('data/avg-{}-{}.dat'.format(chunk_size, batch_chunks), 'wb') as f: filename = 'data/avg-{}-{}.dat'.format(chunk_size, batch_chunks)
with open(filename, 'wb') as f:
pickle.dump((ag, agsqr), f) pickle.dump((ag, agsqr), f)
......
import torchani import torchani
import torch import torch
import os import os
from configs import benchmark, device import configs
class Averager: class Averager:
...@@ -18,17 +18,15 @@ class Averager: ...@@ -18,17 +18,15 @@ class Averager:
return self.subtotal / self.count return self.subtotal / self.count
aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
def celu(x, alpha): def celu(x, alpha):
return torch.where(x > 0, x, alpha * (torch.exp(x/alpha)-1)) return torch.where(x > 0, x, alpha * (torch.exp(x/alpha)-1))
class AtomicNetwork(torch.nn.Module): class AtomicNetwork(torch.nn.Module):
def __init__(self): def __init__(self, aev_computer):
super(AtomicNetwork, self).__init__() super(AtomicNetwork, self).__init__()
self.aev_computer = aev_computer
self.output_length = 1 self.output_length = 1
self.layer1 = torch.nn.Linear(384, 128).type( self.layer1 = torch.nn.Linear(384, 128).type(
aev_computer.dtype).to(aev_computer.device) aev_computer.dtype).to(aev_computer.device)
...@@ -51,16 +49,17 @@ class AtomicNetwork(torch.nn.Module): ...@@ -51,16 +49,17 @@ class AtomicNetwork(torch.nn.Module):
return y return y
def get_or_create_model(filename): def get_or_create_model(filename, benchmark=False, device=configs.device):
aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
model = torchani.ModelOnAEV( model = torchani.ModelOnAEV(
aev_computer, aev_computer,
reducer=torch.sum, reducer=torch.sum,
benchmark=benchmark, benchmark=benchmark,
per_species={ per_species={
'C': AtomicNetwork(), 'C': AtomicNetwork(aev_computer),
'H': AtomicNetwork(), 'H': AtomicNetwork(aev_computer),
'N': AtomicNetwork(), 'N': AtomicNetwork(aev_computer),
'O': AtomicNetwork(), 'O': AtomicNetwork(aev_computer),
}) })
if os.path.isfile(filename): if os.path.isfile(filename):
model.load_state_dict(torch.load(filename)) model.load_state_dict(torch.load(filename))
......
import torch import torch
benchmark = False
data_path = 'data/ANI-1x_complete.h5' data_path = 'data/ANI-1x_complete.h5'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# flake8: noqa
import torch import torch
import torchani import torchani
......
...@@ -2,18 +2,17 @@ import torch ...@@ -2,18 +2,17 @@ import torch
import torchani import torchani
import torchani.data import torchani.data
import tqdm import tqdm
import math
import timeit import timeit
import configs import configs
import functools import functools
configs.benchmark = True from common import get_or_create_model, Averager, evaluate
from common import *
ds = torchani.data.load_dataset(configs.data_path) ds = torchani.data.load_dataset(configs.data_path)
sampler = torchani.data.BatchSampler(ds, 256, 4) sampler = torchani.data.BatchSampler(ds, 256, 4)
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
ds, batch_sampler=sampler, collate_fn=torchani.data.collate, num_workers=20) ds, batch_sampler=sampler,
model = get_or_create_model('/tmp/model.pt') collate_fn=torchani.data.collate, num_workers=20)
model = get_or_create_model('/tmp/model.pt', True)
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True) optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)
...@@ -47,20 +46,20 @@ for batch in tqdm.tqdm(dataloader, total=len(sampler)): ...@@ -47,20 +46,20 @@ for batch in tqdm.tqdm(dataloader, total=len(sampler)):
for molecule_id in batch: for molecule_id in batch:
_species = ds.species[molecule_id] _species = ds.species[molecule_id]
coordinates, energies = batch[molecule_id] coordinates, energies = batch[molecule_id]
coordinates = coordinates.to(aev_computer.device) coordinates = coordinates.to(model.aev_computer.device)
energies = energies.to(aev_computer.device) energies = energies.to(model.aev_computer.device)
a.add(*evaluate(model, coordinates, energies, _species)) a.add(*evaluate(model, coordinates, energies, _species))
optimize_step(a) optimize_step(a)
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', aev_computer.timers['radial terms']) print('Radial terms:', model.aev_computer.timers['radial terms'])
print('Angular terms:', aev_computer.timers['angular terms']) print('Angular terms:', model.aev_computer.timers['angular terms'])
print('Terms and indices:', aev_computer.timers['terms and indices']) print('Terms and indices:', model.aev_computer.timers['terms and indices'])
print('Combinations:', aev_computer.timers['combinations']) print('Combinations:', model.aev_computer.timers['combinations'])
print('Mask R:', aev_computer.timers['mask_r']) print('Mask R:', model.aev_computer.timers['mask_r'])
print('Mask A:', aev_computer.timers['mask_a']) print('Mask A:', model.aev_computer.timers['mask_a'])
print('Assemble:', aev_computer.timers['assemble']) print('Assemble:', model.aev_computer.timers['assemble'])
print('Total AEV:', aev_computer.timers['total']) print('Total AEV:', model.aev_computer.timers['total'])
print('NN:', model.timers['nn']) print('NN:', model.timers['nn'])
print('Total Forward:', model.timers['forward']) print('Total Forward:', model.timers['forward'])
print('Total Backward:', timer['backward']) print('Total Backward:', timer['backward'])
......
import torch import torch
import numpy
import torchani import torchani
import unittest import unittest
import os import os
...@@ -15,7 +14,8 @@ class TestAEV(unittest.TestCase): ...@@ -15,7 +14,8 @@ class TestAEV(unittest.TestCase):
self.aev = torchani.SortedAEV(dtype=dtype, device=torch.device('cpu')) self.aev = torchani.SortedAEV(dtype=dtype, device=torch.device('cpu'))
self.tolerance = 1e-5 self.tolerance = 1e-5
def _test_molecule(self, coordinates, species, expected_radial, expected_angular): def _test_molecule(self, coordinates, species, expected_radial,
expected_angular):
radial, angular = self.aev(coordinates, species) radial, angular = self.aev(coordinates, species)
radial_diff = expected_radial - radial radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item() radial_max_error = torch.max(torch.abs(radial_diff)).item()
......
...@@ -6,7 +6,8 @@ import copy ...@@ -6,7 +6,8 @@ import copy
class TestBenchmark(unittest.TestCase): class TestBenchmark(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype, device=torchani.default_device): def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device):
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.conformations = 100 self.conformations = 100
......
import torchani import torchani
import unittest import unittest
import copy
import tempfile import tempfile
import os import os
import torch import torch
...@@ -28,7 +27,8 @@ class TestDataset(unittest.TestCase): ...@@ -28,7 +27,8 @@ class TestDataset(unittest.TestCase):
l2 = 0 l2 = 0
for f in os.listdir(self.data_path): for f in os.listdir(self.data_path):
f = os.path.join(self.data_path, f) f = os.path.join(self.data_path, f)
if os.path.isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')): if os.path.isfile(f) and \
(f.endswith('.h5') or f.endswith('.hdf5')):
for j in pyanitools.anidataloader(f): for j in pyanitools.anidataloader(f):
l2 += j['energies'].shape[0] l2 += j['energies'].shape[0]
# compute data length using iterator # compute data length using iterator
...@@ -46,7 +46,8 @@ class TestDataset(unittest.TestCase): ...@@ -46,7 +46,8 @@ class TestDataset(unittest.TestCase):
l2 = 0 l2 = 0
for f in os.listdir(self.data_path): for f in os.listdir(self.data_path):
f = os.path.join(self.data_path, f) f = os.path.join(self.data_path, f)
if os.path.isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')): if os.path.isfile(f) and \
(f.endswith('.h5') or f.endswith('.hdf5')):
for j in pyanitools.anidataloader(f): for j in pyanitools.anidataloader(f):
conformations = j['energies'].shape[0] conformations = j['energies'].shape[0]
l2 += ceil(conformations / chunksize) l2 += ceil(conformations / chunksize)
...@@ -102,12 +103,12 @@ class TestDataset(unittest.TestCase): ...@@ -102,12 +103,12 @@ class TestDataset(unittest.TestCase):
def _testMolSizes(self, ds): def _testMolSizes(self, ds):
for i in range(len(ds)): for i in range(len(ds)):
l = bisect(ds.cumulative_sizes, i) left = bisect(ds.cumulative_sizes, i)
moli = ds[i][0].item() moli = ds[i][0].item()
for j in range(len(ds)): for j in range(len(ds)):
l2 = bisect(ds.cumulative_sizes, j) left2 = bisect(ds.cumulative_sizes, j)
molj = ds[j][0].item() molj = ds[j][0].item()
if l == l2: if left == left2:
self.assertEqual(moli, molj) self.assertEqual(moli, molj)
else: else:
if moli == molj: if moli == molj:
......
import torch import torch
import numpy
import torchani import torchani
import unittest import unittest
import os import os
...@@ -12,7 +11,8 @@ N = 97 ...@@ -12,7 +11,8 @@ N = 97
class TestEnergies(unittest.TestCase): class TestEnergies(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype, device=torchani.default_device): def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device):
self.tolerance = 5e-5 self.tolerance = 5e-5
self.aev_computer = torchani.SortedAEV( self.aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu')) dtype=dtype, device=torch.device('cpu'))
......
import torch import torch
import numpy
import torchani import torchani
import unittest import unittest
import os import os
...@@ -11,7 +10,8 @@ N = 97 ...@@ -11,7 +10,8 @@ N = 97
class TestForce(unittest.TestCase): class TestForce(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype, device=torchani.default_device): def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device):
self.tolerance = 1e-5 self.tolerance = 1e-5
self.aev_computer = torchani.SortedAEV( self.aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu')) dtype=dtype, device=torch.device('cpu'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment