Unverified Commit 50fb6cdb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Modify training related APIs to batch the whole data by padding (#63)

parent b7cab4f1
......@@ -9,11 +9,10 @@ sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
aev_computer = torchani.AEVComputer(const_file=const_file)
prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
shift_energy = torchani.EnergyShifter(aev_computer.species, sae_file)
model = torch.nn.Sequential(prepare, aev_computer, nn, shift_energy)
model = torch.nn.Sequential(aev_computer, nn, shift_energy)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
......@@ -21,7 +20,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H']
species = torch.LongTensor([[2, 1, 1, 1, 1]]) # 1 = H, 2 = C, 3 = N, 4 = O
_, energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
......
......@@ -19,7 +19,6 @@ def atomic():
def get_or_create_model(filename, benchmark=False,
device=torch.device('cpu')):
aev_computer = torchani.AEVComputer(benchmark=benchmark)
prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.CustomModel(
benchmark=benchmark,
per_species={
......@@ -34,7 +33,7 @@ def get_or_create_model(filename, benchmark=False,
def forward(self, x):
return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
model = torch.nn.Sequential(aev_computer, model, Flatten())
if os.path.isfile(filename):
model.load_state_dict(torch.load(filename))
else:
......
......@@ -14,12 +14,9 @@ parser.add_argument('dataset_path',
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
parser.add_argument('--const_file',
help='File storing constants',
default=torchani.buildin_const_file)
......@@ -37,12 +34,11 @@ parser = parser.parse_args()
# load modules and datasets
device = torch.device(parser.device)
aev_computer = torchani.AEVComputer(const_file=parser.const_file)
prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species,
from_=parser.network_dir,
ensemble=parser.ensemble)
model = torch.nn.Sequential(prepare, aev_computer, nn)
container = torchani.ignite.Container({'energies': model})
model = torch.nn.Sequential(aev_computer, nn)
container = torchani.training.Container({'energies': model})
container = container.to(device)
# load datasets
......@@ -50,9 +46,9 @@ shift_energy = torchani.EnergyShifter(aev_computer.species, parser.sae_file)
if parser.dataset_path.endswith('.h5') or \
parser.dataset_path.endswith('.hdf5') or \
os.path.isdir(parser.dataset_path):
dataset = torchani.data.ANIDataset(
parser.dataset_path, parser.chunk_size, device=device,
transform=[shift_energy.subtract_from_dataset])
dataset = torchani.training.BatchedANIDataset(
parser.dataset_path, aev_computer.species, parser.batch_size,
device=device, transform=[shift_energy.subtract_from_dataset])
datasets = [dataset]
else:
with open(parser.dataset_path, 'rb') as f:
......@@ -67,11 +63,10 @@ def hartree2kcal(x):
for dataset in datasets:
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
'RMSE': torchani.training.RMSEMetric('energies')
})
evaluator.run(dataloader)
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
print(rmse, 'kcal/mol')
......@@ -29,12 +29,9 @@ parser.add_argument('--training_rmse_every',
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
parser.add_argument('--log',
help='Log directory for tensorboardX',
default=None)
......@@ -56,21 +53,20 @@ start = timeit.default_timer()
nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
True, device=device)
training, validation, testing = torchani.data.load_or_create(
parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
device=device, transform=[shift_energy.subtract_from_dataset])
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)
container = torchani.ignite.Container({'energies': nnp})
training, validation, testing = torchani.training.load_or_create(
parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
parser.dataset_path, device=device,
transform=[shift_energy.subtract_from_dataset])
container = torchani.training.Container({'energies': nnp})
parser.optim_args = json.loads(parser.optim_args)
optimizer = getattr(torch.optim, parser.optimizer)
optimizer = optimizer(nnp.parameters(), **parser.optim_args)
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
container, optimizer, torchani.training.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
'RMSE': torchani.training.RMSEMetric('energies')
})
......
......@@ -14,32 +14,28 @@ parser.add_argument('dataset_path',
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
True, device=device)
dataset = torchani.data.ANIDataset(
parser.dataset_path, parser.chunk_size, device=device,
dataset = torchani.training.BatchedANIDataset(
parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
container = torchani.ignite.Container({'energies': nnp})
container = torchani.training.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
container, optimizer, torchani.training.MSELoss('energies'))
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
trainer.state.tqdm = tqdm.tqdm(total=len(dataloader), desc='epoch')
trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
......@@ -54,15 +50,15 @@ def finalize_tqdm(trainer):
# run it!
start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1)
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', nnp[1].timers['radial terms'])
print('Angular terms:', nnp[1].timers['angular terms'])
print('Terms and indices:', nnp[1].timers['terms and indices'])
print('Combinations:', nnp[1].timers['combinations'])
print('Mask R:', nnp[1].timers['mask_r'])
print('Mask A:', nnp[1].timers['mask_a'])
print('Assemble:', nnp[1].timers['assemble'])
print('Total AEV:', nnp[1].timers['total'])
print('NN:', nnp[2].timers['forward'])
print('Radial terms:', nnp[0].timers['radial terms'])
print('Angular terms:', nnp[0].timers['angular terms'])
print('Terms and indices:', nnp[0].timers['terms and indices'])
print('Combinations:', nnp[0].timers['combinations'])
print('Mask R:', nnp[0].timers['mask_r'])
print('Mask A:', nnp[0].timers['mask_a'])
print('Assemble:', nnp[0].timers['assemble'])
print('Total AEV:', nnp[0].timers['total'])
print('NN:', nnp[1].timers['forward'])
print('Epoch time:', elapsed)
......@@ -13,11 +13,6 @@ class TestAEV(unittest.TestCase):
def setUp(self):
self.aev_computer = torchani.AEVComputer()
self.radial_length = self.aev_computer.radial_length
self.prepare = torchani.PrepareInput(self.aev_computer.species)
self.aev = torch.nn.Sequential(
self.prepare,
self.aev_computer
)
self.tolerance = 1e-5
def _assertAEVEqual(self, expected_radial, expected_angular, aev):
......@@ -36,7 +31,7 @@ class TestAEV(unittest.TestCase):
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _ \
= pickle.load(f)
_, aev = self.aev((species, coordinates))
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev)
def testPadding(self):
......@@ -46,8 +41,7 @@ class TestAEV(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, radial, angular, _, _ = pickle.load(f)
species_coordinates.append(
self.prepare((species, coordinates)))
species_coordinates.append((species, coordinates))
radial_angular.append((radial, angular))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
......
......@@ -8,7 +8,8 @@ class TestBenchmark(unittest.TestCase):
def setUp(self):
self.conformations = 100
self.species = list('HHCCNNOO')
self.species = torch.randint(4, (self.conformations, 8),
dtype=torch.long)
self.coordinates = torch.randn(self.conformations, 8, 3)
self.count = 100
......@@ -79,9 +80,7 @@ class TestBenchmark(unittest.TestCase):
def testAEV(self):
aev_computer = torchani.AEVComputer(benchmark=True)
prepare = torchani.PrepareInput(aev_computer.species)
run_module = torch.nn.Sequential(prepare, aev_computer)
self._testModule(run_module, aev_computer, [
self._testModule(aev_computer, aev_computer, [
'terms and indices>radial terms',
'terms and indices>angular terms',
'total>terms and indices',
......@@ -91,10 +90,9 @@ class TestBenchmark(unittest.TestCase):
def testANIModel(self):
aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species)
model = torchani.models.NeuroChemNNP(aev_computer.species,
benchmark=True)
run_module = torch.nn.Sequential(prepare, aev_computer, model)
run_module = torch.nn.Sequential(aev_computer, model)
self._testModule(run_module, model, ['forward'])
......
import os
import torchani
import unittest
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset')
class TestDataset(unittest.TestCase):
def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize)
for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self):
self._test_chunksize(64)
def testChunk128(self):
self._test_chunksize(128)
def testChunk32(self):
self._test_chunksize(32)
def testChunk256(self):
self._test_chunksize(256)
dataset_path = os.path.join(path, '../dataset')
print(dataset_path)
batch_size = 256
aev = torchani.AEVComputer()
class TestData(unittest.TestCase):
def testTensorShape(self):
ds = torchani.training.BatchedANIDataset(dataset_path, aev.species,
batch_size)
for i in ds:
input, output = i
species, coordinates = input
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
self.assertEqual(len(coordinates.shape), 3)
self.assertEqual(coordinates.shape[2], 3)
self.assertEqual(coordinates.shape[1], coordinates.shape[1])
self.assertEqual(coordinates.shape[0], coordinates.shape[0])
self.assertEqual(len(energies.shape), 1)
self.assertEqual(coordinates.shape[0], energies.shape[0])
if __name__ == '__main__':
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment