Unverified Commit f146feca authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Refactor codes to put neurochem related codes together (#72)

parent 85a6dd1e
......@@ -17,6 +17,7 @@ steps:
UnitTests:
image: '${{BuildTorchANI}}'
commands:
- find . -name '*.pyc' -delete
- python setup.py test
# - python2 setup.py test
......
......@@ -7,11 +7,14 @@ path = os.path.dirname(os.path.realpath(__file__))
const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat') # noqa: E501
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
ensemble = 8
aev_computer = torchani.AEVComputer(const_file=const_file)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
shift_energy = torchani.EnergyShifter(aev_computer.species, sae_file)
consts = torchani.neurochem.Constants(const_file)
sae = torchani.neurochem.load_sae(sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, from_=network_dir,
ensemble=ensemble)
shift_energy = torchani.EnergyShifter(consts.species, sae)
model = torch.nn.Sequential(aev_computer, nn, shift_energy)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
......@@ -20,7 +23,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True)
species = torch.LongTensor([[1, 0, 0, 0, 0]]) # 0 = H, 1 = C, 2 = N, 3 = O
species = consts.species_to_tensor('CHHHH', device).unsqueeze(0)
_, energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
......
......@@ -16,17 +16,16 @@ def atomic():
return model
def get_or_create_model(filename, benchmark=False,
device=torch.device('cpu')):
aev_computer = torchani.AEVComputer(benchmark=benchmark)
model = torchani.models.CustomModel(
benchmark=benchmark,
per_species={
'C': atomic(),
'H': atomic(),
'N': atomic(),
'O': atomic(),
})
def get_or_create_model(filename, device=torch.device('cpu')):
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
model = torchani.ANIModel([
('C', atomic()),
('H', atomic()),
('N', atomic()),
('O', atomic()),
])
class Flatten(torch.nn.Module):
......@@ -38,4 +37,4 @@ def get_or_create_model(filename, benchmark=False,
model.load_state_dict(torch.load(filename))
else:
torch.save(model.state_dict(), filename)
return model.to(device), torchani.EnergyShifter(aev_computer.species)
return model.to(device), torchani.EnergyShifter(consts.species, sae)
......@@ -19,10 +19,10 @@ parser.add_argument('--batch_size',
default=1024, type=int)
parser.add_argument('--const_file',
help='File storing constants',
default=torchani.buildin_const_file)
default=torchani.neurochem.buildin_const_file)
parser.add_argument('--sae_file',
help='File storing self atomic energies',
default=torchani.buildin_sae_file)
default=torchani.neurochem.buildin_sae_file)
parser.add_argument('--network_dir',
help='Directory or prefix of directories storing networks',
default=None)
......@@ -33,8 +33,10 @@ parser = parser.parse_args()
# load modules and datasets
device = torch.device(parser.device)
aev_computer = torchani.AEVComputer(const_file=parser.const_file)
nn = torchani.models.NeuroChemNNP(aev_computer.species,
consts = torchani.neurochem.Constants(parser.const_file)
sae = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species,
from_=parser.network_dir,
ensemble=parser.ensemble)
model = torch.nn.Sequential(aev_computer, nn)
......@@ -42,12 +44,12 @@ container = torchani.training.Container({'energies': model})
container = container.to(device)
# load datasets
shift_energy = torchani.EnergyShifter(aev_computer.species, parser.sae_file)
shift_energy = torchani.EnergyShifter(consts.species, sae)
if parser.dataset_path.endswith('.h5') or \
parser.dataset_path.endswith('.hdf5') or \
os.path.isdir(parser.dataset_path):
dataset = torchani.training.BatchedANIDataset(
parser.dataset_path, aev_computer.species, parser.batch_size,
parser.dataset_path, consts.species, parser.batch_size,
device=device, transform=[shift_energy.subtract_from_dataset])
datasets = [dataset]
else:
......
......@@ -52,7 +52,7 @@ writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
True, device=device)
device=device)
training, validation, testing = torchani.training.load_or_create(
parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
parser.dataset_path, device=device,
......
......@@ -21,8 +21,7 @@ parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
True, device=device)
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt', device=device)
dataset = torchani.training.BatchedANIDataset(
parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
......@@ -48,17 +47,47 @@ def finalize_tqdm(trainer):
trainer.state.tqdm.close()
timers = {}
def time_func(key, func):
timers[key] = 0
def wrapper(*args, **kwargs):
start = timeit.default_timer()
ret = func(*args, **kwargs)
end = timeit.default_timer()
timers[key] += end - start
return ret
return wrapper
# enable timers
nnp[0].radial_subaev_terms = time_func('radial terms',
nnp[0].radial_subaev_terms)
nnp[0].angular_subaev_terms = time_func('angular terms',
nnp[0].angular_subaev_terms)
nnp[0].terms_and_indices = time_func('terms and indices',
nnp[0].terms_and_indices)
nnp[0].combinations = time_func('combinations', nnp[0].combinations)
nnp[0].compute_mask_r = time_func('mask_r', nnp[0].compute_mask_r)
nnp[0].compute_mask_a = time_func('mask_a', nnp[0].compute_mask_a)
nnp[0].assemble = time_func('assemble', nnp[0].assemble)
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)
# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', nnp[0].timers['radial terms'])
print('Angular terms:', nnp[0].timers['angular terms'])
print('Terms and indices:', nnp[0].timers['terms and indices'])
print('Combinations:', nnp[0].timers['combinations'])
print('Mask R:', nnp[0].timers['mask_r'])
print('Mask A:', nnp[0].timers['mask_a'])
print('Assemble:', nnp[0].timers['assemble'])
print('Total AEV:', nnp[0].timers['total'])
print('NN:', nnp[1].timers['forward'])
print('Radial terms:', timers['radial terms'])
print('Angular terms:', timers['angular terms'])
print('Terms and indices:', timers['terms and indices'])
print('Combinations:', timers['combinations'])
print('Mask R:', timers['mask_r'])
print('Mask A:', timers['mask_a'])
print('Assemble:', timers['assemble'])
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
......@@ -11,8 +11,9 @@ N = 97
class TestAEV(unittest.TestCase):
def setUp(self):
self.aev_computer = torchani.AEVComputer()
self.radial_length = self.aev_computer.radial_length
self.constants = torchani.neurochem.Constants()
self.aev_computer = torchani.AEVComputer(**self.constants)
self.radial_length = self.aev_computer.radial_length()
self.tolerance = 1e-5
def _assertAEVEqual(self, expected_radial, expected_angular, aev):
......
import torch
import torchani
import unittest
import copy
class TestBenchmark(unittest.TestCase):
def setUp(self):
self.conformations = 100
self.species = torch.randint(4, (self.conformations, 8),
dtype=torch.long)
self.coordinates = torch.randn(self.conformations, 8, 3)
self.count = 100
def _testModule(self, run_module, result_module, asserts):
keys = []
for i in asserts:
if '>=' in i:
i = i.split('>=')
keys += [i[0].strip(), i[1].strip()]
elif '<=' in i:
i = i.split('<=')
keys += [i[0].strip(), i[1].strip()]
elif '>' in i:
i = i.split('>')
keys += [i[0].strip(), i[1].strip()]
elif '<' in i:
i = i.split('<')
keys += [i[0].strip(), i[1].strip()]
elif '=' in i:
i = i.split('=')
keys += [i[0].strip(), i[1].strip()]
else:
keys.append(i.strip())
self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys:
self.assertEqual(result_module.timers[i], 0)
old_timers = copy.copy(result_module.timers)
for _ in range(self.count):
run_module((self.species, self.coordinates))
for i in keys:
self.assertLess(old_timers[i], result_module.timers[i])
for i in asserts:
if '>=' in i:
i = i.split('>=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertGreaterEqual(
result_module.timers[key0], result_module.timers[key1])
elif '<=' in i:
i = i.split('<=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertLessEqual(
result_module.timers[key0], result_module.timers[key1])
elif '>' in i:
i = i.split('>')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertGreater(
result_module.timers[key0], result_module.timers[key1])
elif '<' in i:
i = i.split('<')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertLess(result_module.timers[key0],
result_module.timers[key1])
elif '=' in i:
i = i.split('=')
key0 = i[0].strip()
key1 = i[1].strip()
self.assertEqual(result_module.timers[key0],
result_module.timers[key1])
old_timers = copy.copy(result_module.timers)
result_module.reset_timers()
self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys:
self.assertEqual(result_module.timers[i], 0)
def testAEV(self):
aev_computer = torchani.AEVComputer(benchmark=True)
self._testModule(aev_computer, aev_computer, [
'terms and indices>radial terms',
'terms and indices>angular terms',
'total>terms and indices',
'total>combinations', 'total>assemble',
'total>mask_r', 'total>mask_a'
])
def testANIModel(self):
aev_computer = torchani.AEVComputer()
model = torchani.models.NeuroChemNNP(aev_computer.species,
benchmark=True)
run_module = torch.nn.Sequential(aev_computer, model)
self._testModule(run_module, model, ['forward'])
if __name__ == '__main__':
unittest.main()
......@@ -6,14 +6,14 @@ import unittest
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
batch_size = 256
aev = torchani.AEVComputer()
consts = torchani.neurochem.Constants()
class TestData(unittest.TestCase):
def setUp(self):
self.ds = torchani.training.BatchedANIDataset(dataset_path,
aev.species,
consts.species,
batch_size)
def _assertTensorEqual(self, t1, t2):
......
......@@ -13,9 +13,11 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
aev_computer = torchani.AEVComputer()
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
shift_energy = torchani.EnergyShifter(consts.species, sae)
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
def testIsomers(self):
......
......@@ -16,13 +16,14 @@ class TestEnsemble(unittest.TestCase):
def _test_molecule(self, coordinates, species):
coordinates = torch.tensor(coordinates, requires_grad=True)
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.AEVComputer()
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
n = torchani.neurochem.buildin_ensemble
prefix = torchani.neurochem.buildin_model_prefix
consts = torchani.neurochem.Constants()
aev = torchani.AEVComputer(**consts)
ensemble = torchani.neurochem.load_model(consts.species, ensemble=True)
ensemble = torch.nn.Sequential(aev, ensemble)
models = [torchani.models.
NeuroChemNNP(aev.species, ensemble=False,
models = [torchani.neurochem.load_model(
consts.species, ensemble=False,
from_=prefix + '{}/networks/'.format(i))
for i in range(n)]
models = [torch.nn.Sequential(aev, m) for m in models]
......
......@@ -12,8 +12,9 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
aev_computer = torchani.AEVComputer()
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
consts = torchani.neurochem.Constants()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
self.model = torch.nn.Sequential(aev_computer, nnp)
def testIsomers(self):
......
......@@ -15,11 +15,13 @@ threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
aev_computer = torchani.AEVComputer()
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
shift_energy = torchani.EnergyShifter(consts.species, sae)
ds = torchani.training.BatchedANIDataset(
path, aev_computer.species, batchsize,
path, consts.species, batchsize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
......
......@@ -58,12 +58,13 @@ class NeuroChem (torchani.aev.AEVComputer):
energies, forces
aev = torchani.AEVComputer()
consts = torchani.neurochem.Constants()
aev = torchani.AEVComputer(**consts)
ncaev = NeuroChem().to(torch.device('cpu'))
mol_count = 0
species_indices = {aev.species[i]: i for i in range(len(aev.species))}
species_indices = {consts.species[i]: i for i in range(len(aev.species))}
for i in [1, 2, 3, 4]:
data_file = os.path.join(
path, '../dataset/ani_gdb_s0{}.h5'.format(i))
......
from .energyshifter import EnergyShifter
from . import models
from .models import ANIModel, Ensemble
from .aev import AEVComputer
from . import training
from . import padding
from .aev import AEVComputer
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble
from . import neurochem
__all__ = ['PrepareInput', 'AEVComputer', 'EnergyShifter',
'models', 'training', 'padding', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble']
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
'training', 'padding', 'neurochem']
# file for python 2 compatibility
import math
if not hasattr(math, 'inf'):
math.inf = float('inf')
import torch
import itertools
import math
from .env import buildin_const_file
from .benchmarked import BenchmarkedModule
from . import padding
class AEVComputerBase(BenchmarkedModule):
__constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
'angular_sublength', 'angular_length', 'aev_length']
"""Base class of various implementations of AEV computer
Attributes
----------
benchmark : boolean
Whether to enable benchmark
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
Cutoff radius
EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
radial_sublength : int
The length of radial subaev of a single species
radial_length : int
The length of full radial aev
angular_sublength : int
The length of angular subaev of a single species
angular_length : int
The length of full angular aev
aev_length : int
The length of full aev
"""
def __init__(self, benchmark=False, const_file=buildin_const_file):
super(AEVComputerBase, self).__init__(benchmark)
self.const_file = const_file
# load constants from const file
const = {}
with open(const_file) as f:
for i in f:
try:
line = [x.strip() for x in i.split('=')]
name = line[0]
value = line[1]
if name == 'Rcr' or name == 'Rca':
setattr(self, name, float(value))
elif name in ['EtaR', 'ShfR', 'Zeta',
'ShfZ', 'EtaA', 'ShfA']:
value = [float(x.strip()) for x in value.replace(
'[', '').replace(']', '').split(',')]
value = torch.tensor(value)
const[name] = value
elif name == 'Atyp':
value = [x.strip() for x in value.replace(
'[', '').replace(']', '').split(',')]
self.species = value
except Exception:
raise ValueError('unable to parse const file')
# Compute lengths
self.radial_sublength = const['EtaR'].shape[0] * const['ShfR'].shape[0]
self.radial_length = len(self.species) * self.radial_sublength
self.angular_sublength = const['EtaA'].shape[0] * \
const['Zeta'].shape[0] * const['ShfA'].shape[0] * \
const['ShfZ'].shape[0]
species = len(self.species)
self.angular_length = int(
(species * (species + 1)) / 2) * self.angular_sublength
self.aev_length = self.radial_length + self.angular_length
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
const['EtaR'] = const['EtaR'].view(-1, 1)
const['ShfR'] = const['ShfR'].view(1, -1)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
const['EtaA'] = const['EtaA'].view(-1, 1, 1, 1)
const['Zeta'] = const['Zeta'].view(1, -1, 1, 1)
const['ShfA'] = const['ShfA'].view(1, 1, -1, 1)
const['ShfZ'] = const['ShfZ'].view(1, 1, 1, -1)
# register buffers
for i in const:
self.register_buffer(i, const[i])
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
Parameters
----------
(species, coordinates)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
Returns
-------
(torch.Tensor, torch.LongTensor)
Returns full AEV and species
"""
raise NotImplementedError('subclass must override this method')
def _cutoff_cosine(distances, cutoff):
"""Compute the elementwise cutoff cosine function
......@@ -136,35 +33,56 @@ def _cutoff_cosine(distances, cutoff):
)
class AEVComputer(AEVComputerBase):
"""The AEV computer assuming input coordinates sorted by species
class AEVComputer(torch.nn.Module):
"""AEV computer
Attributes
----------
timers : dict
Dictionary storing the the benchmark result. It has the following keys:
radial_subaev : time spent on computing radial subaev
angular_subaev : time spent on computing angular subaev
total : total time for computing everything.
filename : str
The name of the file that stores constant.
Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
species : list(str)
Chemical symbols of supported atom types
"""
def __init__(self, benchmark=False, const_file=buildin_const_file):
super(AEVComputer, self).__init__(benchmark, const_file)
if benchmark:
self.radial_subaev_terms = self._enable_benchmark(
self.radial_subaev_terms, 'radial terms')
self.angular_subaev_terms = self._enable_benchmark(
self.angular_subaev_terms, 'angular terms')
self.terms_and_indices = self._enable_benchmark(
self.terms_and_indices, 'terms and indices')
self.combinations = self._enable_benchmark(
self.combinations, 'combinations')
self.compute_mask_r = self._enable_benchmark(
self.compute_mask_r, 'mask_r')
self.compute_mask_a = self._enable_benchmark(
self.compute_mask_a, 'mask_a')
self.assemble = self._enable_benchmark(self.assemble, 'assemble')
self.forward = self._enable_benchmark(self.forward, 'total')
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, species):
super(AEVComputer, self).__init__()
self.register_buffer('Rcr', Rcr)
self.register_buffer('Rca', Rca)
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1))
self.register_buffer('ShfR', ShfR.view(1, -1))
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self.register_buffer('EtaA', EtaA.view(-1, 1, 1, 1))
self.register_buffer('Zeta', Zeta.view(1, -1, 1, 1))
self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.species = species
def radial_sublength(self):
"""Returns the length of radial subaev of a single species"""
return self.EtaR.numel() * self.ShfR.numel()
def radial_length(self):
"""Returns the length of full radial aev"""
return len(self.species) * self.radial_sublength()
def angular_sublength(self):
"""Returns the length of angular subaev of a single species"""
return self.EtaA.numel() * self.Zeta.numel() * self.ShfA.numel() * \
self.ShfZ.numel()
def angular_length(self):
"""Returns the length of full angular aev"""
species = len(self.species)
return int((species * (species + 1)) / 2) * self.angular_sublength()
def aev_length(self):
"""Returns the length of full aev"""
return self.radial_length() + self.angular_length()
def radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors
......@@ -316,13 +234,13 @@ class AEVComputer(AEVComputerBase):
def combinations(self, tensor, dim=0):
n = tensor.shape[dim]
r = torch.arange(n).type(torch.long).to(tensor.device)
r = torch.arange(n, dtype=torch.long, device=tensor.device)
grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y.masked_select(
torch.triu(torch.ones(n, n, device=self.EtaR.device),
torch.triu(torch.ones(n, n, device=tensor.device),
diagonal=1) == 1)
index2 = grid_x.masked_select(
torch.triu(torch.ones(n, n, device=self.EtaR.device),
torch.triu(torch.ones(n, n, device=tensor.device),
diagonal=1) == 1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
......@@ -427,7 +345,7 @@ class AEVComputer(AEVComputerBase):
angular_aevs = []
zero_angular_subaev = torch.zeros(
# TODO: can we make stack and cat broadcast?
conformations, atoms, self.angular_sublength,
conformations, atoms, self.angular_sublength(),
dtype=self.EtaR.dtype, device=self.EtaR.device)
for s1, s2 in itertools.combinations_with_replacement(
range(len(self.species)), 2):
......
import torch
import timeit
import functools
class BenchmarkedModule(torch.nn.Module):
"""Module with member function benchmarking support.
The benchmarking is done by wrapping the original member function with
a wrapped function. The wrapped function will call the original function,
and accumulate its running time into `self.timers`. Different accumulators
are distinguished by different keys. All times should have unit seconds.
To enable benchmarking for member functions in a subclass, simply
call the `__init__` of this class with `benchmark=True`, and add the
following code to your subclass's `__init__`:
```
if self.benchmark:
self._enable_benchmark(self.function_to_be_benchmarked, 'key1', 'key2')
```
Example
-------
The following code implements a subclass for timing the running time of
member function `f` and `g` and the total of these two::
```
class BenchmarkFG(BenchmarkedModule):
def __init__(self, benchmark=False)
super(BenchmarkFG, self).__init__(benchmark)
if benchmark:
self.f = self._enable_benchmark(self.f, 'function f', 'total')
self.g = self._enable_benchmark(self.g, 'function g', 'total')
def f(self):
print('in function f')
def g(self):
print('in function g')
```
Attributes
----------
benchmark : boolean
Whether benchmark is enabled
timers : dict
Dictionary storing the the benchmark result.
"""
def _enable_benchmark(self, fun, *keys):
"""Wrap a function to automatically benchmark it, and assign a key
for it.
Parameters
----------
keys
The keys in `self.timers` assigned. If multiple keys are specified,
then the time will be accumulated to all the keys.
func : function
The function to be benchmarked.
Returns
-------
function
Wrapped function that time the original function and update the
corresponding value in `self.timers` automatically.
"""
for key in keys:
self.timers[key] = 0
@functools.wraps(fun)
def wrapped(*args, **kwargs):
start = timeit.default_timer()
ret = fun(*args, **kwargs)
end = timeit.default_timer()
for key in keys:
self.timers[key] += end - start
return ret
return wrapped
def reset_timers(self):
"""Reset all timers. If benchmark is not enabled, a `ValueError`
will be raised."""
if not self.benchmark:
raise ValueError('Can not reset timers, benchmark not enabled')
for i in self.timers:
self.timers[i] = 0
def __init__(self, benchmark=False):
super(BenchmarkedModule, self).__init__()
self.benchmark = benchmark
if benchmark:
self.timers = {}
import torch
from .env import buildin_sae_file
class EnergyShifter(torch.nn.Module):
def __init__(self, species, self_energy_file=buildin_sae_file):
def __init__(self, species, self_energies):
super(EnergyShifter, self).__init__()
# load self energies
self.self_energies = {}
with open(self_energy_file) as f:
for i in f:
try:
line = [x.strip() for x in i.split('=')]
name = line[0].split(',')[0].strip()
value = float(line[1])
self.self_energies[name] = value
except Exception:
pass # ignore unrecognizable line
self_energies_tensor = [self.self_energies[s] for s in species]
self_energies_tensor = [self_energies[s] for s in species]
self.register_buffer('self_energies_tensor',
torch.tensor(self_energies_tensor,
dtype=torch.double))
......
import pkg_resources
buildin_const_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
buildin_sae_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
buildin_network_dir = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train0/networks/')
buildin_model_prefix = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train')
buildin_ensemble = 8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment