Unverified commit 50fb6cdb, authored by Gao, Xiang, committed by GitHub

Modify training related APIs to batch the whole data by padding (#63)

parent b7cab4f1
@@ -9,11 +9,10 @@ sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit
 network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train')  # noqa: E501
 aev_computer = torchani.AEVComputer(const_file=const_file)
-prepare = torchani.PrepareInput(aev_computer.species)
 nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
                                   ensemble=8)
 shift_energy = torchani.EnergyShifter(aev_computer.species, sae_file)
-model = torch.nn.Sequential(prepare, aev_computer, nn, shift_energy)
+model = torch.nn.Sequential(aev_computer, nn, shift_energy)
 coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [-0.83140486, 0.39370209, -0.26395324],
@@ -21,7 +20,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
                             requires_grad=True)
-species = ['C', 'H', 'H', 'H', 'H']
+species = torch.LongTensor([[2, 1, 1, 1, 1]])  # 1 = H, 2 = C, 3 = N, 4 = O
 _, energy = model((species, coordinates))
 derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
......
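The heart of this change is visible above: `species` is now a batched integer tensor rather than a list of strings, so conformations of different molecules can be padded into one batch. Below is a minimal sketch of that idea, assuming (as the AEV padding test further down does) that `torchani.padding.pad_and_batch` accepts a list of `(species, coordinates)` pairs and returns a single padded pair; the water molecule and the random coordinates are made up for illustration:

```python
import torch
import torchani

# Integer species encoding from the example above: 1 = H, 2 = C, 3 = N, 4 = O.
methane = (torch.LongTensor([[2, 1, 1, 1, 1]]), torch.randn(1, 5, 3))  # 5 atoms
water = (torch.LongTensor([[4, 1, 1]]), torch.randn(1, 3, 3))          # 3 atoms

# pad_and_batch merges molecules of different sizes into one minibatch
# by padding the smaller species/coordinate tensors up to 5 atoms.
species, coordinates = torchani.padding.pad_and_batch([methane, water])
print(species.shape)      # expected: torch.Size([2, 5])
print(coordinates.shape)  # expected: torch.Size([2, 5, 3])
```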
@@ -19,7 +19,6 @@ def atomic():
 def get_or_create_model(filename, benchmark=False,
                         device=torch.device('cpu')):
     aev_computer = torchani.AEVComputer(benchmark=benchmark)
-    prepare = torchani.PrepareInput(aev_computer.species)
     model = torchani.models.CustomModel(
         benchmark=benchmark,
         per_species={
@@ -34,7 +33,7 @@ def get_or_create_model(filename, benchmark=False,
         def forward(self, x):
             return x[0], x[1].flatten()
-    model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
+    model = torch.nn.Sequential(aev_computer, model, Flatten())
     if os.path.isfile(filename):
         model.load_state_dict(torch.load(filename))
     else:
......
@@ -14,12 +14,9 @@ parser.add_argument('dataset_path',
 parser.add_argument('-d', '--device',
                     help='Device of modules and tensors',
                     default=('cuda' if torch.cuda.is_available() else 'cpu'))
-parser.add_argument('--chunk_size',
-                    help='Number of conformations of each chunk',
-                    default=256, type=int)
-parser.add_argument('--batch_chunks',
-                    help='Number of chunks in each minibatch',
-                    default=4, type=int)
+parser.add_argument('--batch_size',
+                    help='Number of conformations of each batch',
+                    default=1024, type=int)
 parser.add_argument('--const_file',
                     help='File storing constants',
                     default=torchani.buildin_const_file)
@@ -37,12 +34,11 @@ parser = parser.parse_args()
 # load modules and datasets
 device = torch.device(parser.device)
 aev_computer = torchani.AEVComputer(const_file=parser.const_file)
-prepare = torchani.PrepareInput(aev_computer.species)
 nn = torchani.models.NeuroChemNNP(aev_computer.species,
                                   from_=parser.network_dir,
                                   ensemble=parser.ensemble)
-model = torch.nn.Sequential(prepare, aev_computer, nn)
-container = torchani.ignite.Container({'energies': model})
+model = torch.nn.Sequential(aev_computer, nn)
+container = torchani.training.Container({'energies': model})
 container = container.to(device)

 # load datasets
@@ -50,9 +46,9 @@ shift_energy = torchani.EnergyShifter(aev_computer.species, parser.sae_file)
 if parser.dataset_path.endswith('.h5') or \
         parser.dataset_path.endswith('.hdf5') or \
         os.path.isdir(parser.dataset_path):
-    dataset = torchani.data.ANIDataset(
-        parser.dataset_path, parser.chunk_size, device=device,
-        transform=[shift_energy.subtract_from_dataset])
+    dataset = torchani.training.BatchedANIDataset(
+        parser.dataset_path, aev_computer.species, parser.batch_size,
+        device=device, transform=[shift_energy.subtract_from_dataset])
     datasets = [dataset]
 else:
     with open(parser.dataset_path, 'rb') as f:
@@ -67,11 +63,10 @@ def hartree2kcal(x):
 for dataset in datasets:
-    dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
     evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-        'RMSE': torchani.ignite.RMSEMetric('energies')
+        'RMSE': torchani.training.RMSEMetric('energies')
     })
-    evaluator.run(dataloader)
+    evaluator.run(dataset)
     metrics = evaluator.state.metrics
     rmse = hartree2kcal(metrics['RMSE'])
     print(rmse, 'kcal/mol')
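Note that `torchani.data.dataloader` is gone entirely: a `BatchedANIDataset` already yields ready-made padded minibatches, so it is handed to `evaluator.run` directly. A minimal sketch of what one item looks like, assuming the `((species, coordinates), {'energies': ...})` structure checked by the data test below; the dataset path is a placeholder:

```python
import torchani

aev_computer = torchani.AEVComputer(const_file=torchani.buildin_const_file)
dataset = torchani.training.BatchedANIDataset(
    'path/to/dataset.h5',          # placeholder; any ANI-format .h5 file
    aev_computer.species, 1024)

# Each item is already a whole padded minibatch, not a single conformation.
(species, coordinates), output = next(iter(dataset))
energies = output['energies']
assert species.shape[0] == coordinates.shape[0] == energies.shape[0]
```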
@@ -29,12 +29,9 @@ parser.add_argument('--training_rmse_every',
 parser.add_argument('-d', '--device',
                     help='Device of modules and tensors',
                     default=('cuda' if torch.cuda.is_available() else 'cpu'))
-parser.add_argument('--chunk_size',
-                    help='Number of conformations of each chunk',
-                    default=256, type=int)
-parser.add_argument('--batch_chunks',
-                    help='Number of chunks in each minibatch',
-                    default=4, type=int)
+parser.add_argument('--batch_size',
+                    help='Number of conformations of each batch',
+                    default=1024, type=int)
 parser.add_argument('--log',
                     help='Log directory for tensorboardX',
                     default=None)
@@ -56,21 +53,20 @@ start = timeit.default_timer()
 nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
                                               True, device=device)
-training, validation, testing = torchani.data.load_or_create(
-    parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
-    device=device, transform=[shift_energy.subtract_from_dataset])
-training = torchani.data.dataloader(training, parser.batch_chunks)
-validation = torchani.data.dataloader(validation, parser.batch_chunks)
-container = torchani.ignite.Container({'energies': nnp})
+training, validation, testing = torchani.training.load_or_create(
+    parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
+    parser.dataset_path, device=device,
+    transform=[shift_energy.subtract_from_dataset])
+container = torchani.training.Container({'energies': nnp})
 parser.optim_args = json.loads(parser.optim_args)
 optimizer = getattr(torch.optim, parser.optimizer)
 optimizer = optimizer(nnp.parameters(), **parser.optim_args)
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.ignite.MSELoss('energies'))
+    container, optimizer, torchani.training.MSELoss('energies'))
 evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-    'RMSE': torchani.ignite.RMSEMetric('energies')
+    'RMSE': torchani.training.RMSEMetric('energies')
 })
......
@@ -14,32 +14,28 @@ parser.add_argument('dataset_path',
 parser.add_argument('-d', '--device',
                     help='Device of modules and tensors',
                     default=('cuda' if torch.cuda.is_available() else 'cpu'))
-parser.add_argument('--chunk_size',
-                    help='Number of conformations of each chunk',
-                    default=256, type=int)
-parser.add_argument('--batch_chunks',
-                    help='Number of chunks in each minibatch',
-                    default=4, type=int)
+parser.add_argument('--batch_size',
+                    help='Number of conformations of each batch',
+                    default=1024, type=int)
 parser = parser.parse_args()

 # set up benchmark
 device = torch.device(parser.device)
 nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
                                               True, device=device)
-dataset = torchani.data.ANIDataset(
-    parser.dataset_path, parser.chunk_size, device=device,
-    transform=[shift_energy.subtract_from_dataset])
-dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
-container = torchani.ignite.Container({'energies': nnp})
+dataset = torchani.training.BatchedANIDataset(
+    parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
+    transform=[shift_energy.subtract_from_dataset])
+container = torchani.training.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.ignite.MSELoss('energies'))
+    container, optimizer, torchani.training.MSELoss('energies'))

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
 def init_tqdm(trainer):
-    trainer.state.tqdm = tqdm.tqdm(total=len(dataloader), desc='epoch')
+    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')

 @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
@@ -54,15 +50,15 @@ def finalize_tqdm(trainer):
 # run it!
 start = timeit.default_timer()
-trainer.run(dataloader, max_epochs=1)
+trainer.run(dataset, max_epochs=1)
 elapsed = round(timeit.default_timer() - start, 2)
-print('Radial terms:', nnp[1].timers['radial terms'])
-print('Angular terms:', nnp[1].timers['angular terms'])
-print('Terms and indices:', nnp[1].timers['terms and indices'])
-print('Combinations:', nnp[1].timers['combinations'])
-print('Mask R:', nnp[1].timers['mask_r'])
-print('Mask A:', nnp[1].timers['mask_a'])
-print('Assemble:', nnp[1].timers['assemble'])
-print('Total AEV:', nnp[1].timers['total'])
-print('NN:', nnp[2].timers['forward'])
+print('Radial terms:', nnp[0].timers['radial terms'])
+print('Angular terms:', nnp[0].timers['angular terms'])
+print('Terms and indices:', nnp[0].timers['terms and indices'])
+print('Combinations:', nnp[0].timers['combinations'])
+print('Mask R:', nnp[0].timers['mask_r'])
+print('Mask A:', nnp[0].timers['mask_a'])
+print('Assemble:', nnp[0].timers['assemble'])
+print('Total AEV:', nnp[0].timers['total'])
+print('NN:', nnp[1].timers['forward'])
 print('Epoch time:', elapsed)
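Why every timer index dropped by one: removing `PrepareInput` from the `Sequential` shifts each remaining module up a slot. A sketch of the layout, using only constructors that appear in this diff:

```python
import torch
import torchani

aev_computer = torchani.AEVComputer(benchmark=True)
model = torchani.models.NeuroChemNNP(aev_computer.species, benchmark=True)

# Before: Sequential(prepare, aev_computer, model)
#   -> AEV timers at nnp[1], network 'forward' timer at nnp[2].
# After: PrepareInput is gone, so everything moves up one slot.
nnp = torch.nn.Sequential(aev_computer, model)
assert nnp[0] is aev_computer  # holds the 'radial terms' ... 'total' timers
assert nnp[1] is model         # holds the 'forward' timer
```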
@@ -13,11 +13,6 @@ class TestAEV(unittest.TestCase):
     def setUp(self):
         self.aev_computer = torchani.AEVComputer()
         self.radial_length = self.aev_computer.radial_length
-        self.prepare = torchani.PrepareInput(self.aev_computer.species)
-        self.aev = torch.nn.Sequential(
-            self.prepare,
-            self.aev_computer
-        )
         self.tolerance = 1e-5

     def _assertAEVEqual(self, expected_radial, expected_angular, aev):
@@ -36,7 +31,7 @@ class TestAEV(unittest.TestCase):
             with open(datafile, 'rb') as f:
                 coordinates, species, expected_radial, expected_angular, _, _ \
                     = pickle.load(f)
-            _, aev = self.aev((species, coordinates))
+            _, aev = self.aev_computer((species, coordinates))
             self._assertAEVEqual(expected_radial, expected_angular, aev)

     def testPadding(self):
@@ -46,8 +41,7 @@ class TestAEV(unittest.TestCase):
             datafile = os.path.join(path, 'test_data/{}'.format(i))
             with open(datafile, 'rb') as f:
                 coordinates, species, radial, angular, _, _ = pickle.load(f)
-            species_coordinates.append(
-                self.prepare((species, coordinates)))
+            species_coordinates.append((species, coordinates))
             radial_angular.append((radial, angular))
         species, coordinates = torchani.padding.pad_and_batch(
             species_coordinates)
......
@@ -8,7 +8,8 @@ class TestBenchmark(unittest.TestCase):
     def setUp(self):
         self.conformations = 100
-        self.species = list('HHCCNNOO')
+        self.species = torch.randint(4, (self.conformations, 8),
+                                     dtype=torch.long)
         self.coordinates = torch.randn(self.conformations, 8, 3)
         self.count = 100
@@ -79,9 +80,7 @@ class TestBenchmark(unittest.TestCase):
     def testAEV(self):
         aev_computer = torchani.AEVComputer(benchmark=True)
-        prepare = torchani.PrepareInput(aev_computer.species)
-        run_module = torch.nn.Sequential(prepare, aev_computer)
-        self._testModule(run_module, aev_computer, [
+        self._testModule(aev_computer, aev_computer, [
             'terms and indices>radial terms',
             'terms and indices>angular terms',
             'total>terms and indices',
@@ -91,10 +90,9 @@ class TestBenchmark(unittest.TestCase):
     def testANIModel(self):
         aev_computer = torchani.AEVComputer()
-        prepare = torchani.PrepareInput(aev_computer.species)
         model = torchani.models.NeuroChemNNP(aev_computer.species,
                                              benchmark=True)
-        run_module = torch.nn.Sequential(prepare, aev_computer, model)
+        run_module = torch.nn.Sequential(aev_computer, model)
         self._testModule(run_module, model, ['forward'])
......
 import os
+import torchani
 import unittest
-import torchani.data

 path = os.path.dirname(os.path.realpath(__file__))
-path = os.path.join(path, '../dataset')
+dataset_path = os.path.join(path, '../dataset')
+print(dataset_path)
+batch_size = 256
+aev = torchani.AEVComputer()


-class TestDataset(unittest.TestCase):
+class TestData(unittest.TestCase):

-    def _test_chunksize(self, chunksize):
-        ds = torchani.data.ANIDataset(path, chunksize)
-        for i, _ in ds:
-            self.assertLessEqual(i['coordinates'].shape[0], chunksize)
-
-    def testChunk64(self):
-        self._test_chunksize(64)
-
-    def testChunk128(self):
-        self._test_chunksize(128)
-
-    def testChunk32(self):
-        self._test_chunksize(32)
-
-    def testChunk256(self):
-        self._test_chunksize(256)
+    def testTensorShape(self):
+        ds = torchani.training.BatchedANIDataset(dataset_path, aev.species,
+                                                 batch_size)
+        for i in ds:
+            input, output = i
+            species, coordinates = input
+            energies = output['energies']
+            self.assertEqual(len(species.shape), 2)
+            self.assertLessEqual(species.shape[0], batch_size)
+            self.assertEqual(len(coordinates.shape), 3)
+            self.assertEqual(coordinates.shape[2], 3)
+            self.assertEqual(species.shape[1], coordinates.shape[1])
+            self.assertEqual(species.shape[0], coordinates.shape[0])
+            self.assertEqual(len(energies.shape), 1)
+            self.assertEqual(coordinates.shape[0], energies.shape[0])

 if __name__ == '__main__':
......
(12 binary files changed; no preview available)