"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "038ed9999b897625b708cbb2591dc702dc39e513"
Unverified Commit e2e740d0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

completely decouple aev computer and neural networks (#40)

parent ee31b752
...@@ -9,8 +9,10 @@ sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit ...@@ -9,8 +9,10 @@ sae_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/sae_linfit
network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501 network_dir = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/train') # noqa: E501
aev_computer = torchani.SortedAEV(const_file=const_file, device=device) aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
nn = torchani.models.NeuroChemNNP(aev_computer, derivative=True, prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
from_=network_dir, ensemble=8) nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
model = torch.nn.Sequential(prepare, aev_computer, nn)
shift_energy = torchani.EnergyShifter(sae_file) shift_energy = torchani.EnergyShifter(sae_file)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
...@@ -23,7 +25,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], ...@@ -23,7 +25,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
requires_grad=True) requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H'] species = ['C', 'H', 'H', 'H', 'H']
energy = nn(coordinates, species) energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0] derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species) energy = shift_energy.add_sae(energy, species)
force = -derivative force = -derivative
......
...@@ -9,9 +9,8 @@ def celu(x, alpha): ...@@ -9,9 +9,8 @@ def celu(x, alpha):
class AtomicNetwork(torch.nn.Module): class AtomicNetwork(torch.nn.Module):
def __init__(self, aev_computer): def __init__(self):
super(AtomicNetwork, self).__init__() super(AtomicNetwork, self).__init__()
self.aev_computer = aev_computer
self.output_length = 1 self.output_length = 1
self.layer1 = torch.nn.Linear(384, 128) self.layer1 = torch.nn.Linear(384, 128)
self.layer2 = torch.nn.Linear(128, 128) self.layer2 = torch.nn.Linear(128, 128)
...@@ -33,16 +32,23 @@ class AtomicNetwork(torch.nn.Module): ...@@ -33,16 +32,23 @@ class AtomicNetwork(torch.nn.Module):
def get_or_create_model(filename, benchmark=False, def get_or_create_model(filename, benchmark=False,
device=torchani.default_device): device=torchani.default_device):
aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device) aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
prepare = torchani.PrepareInput(aev_computer.species, aev_computer.device)
model = torchani.models.CustomModel( model = torchani.models.CustomModel(
aev_computer,
reducer=torch.sum, reducer=torch.sum,
benchmark=benchmark, benchmark=benchmark,
per_species={ per_species={
'C': AtomicNetwork(aev_computer), 'C': AtomicNetwork(),
'H': AtomicNetwork(aev_computer), 'H': AtomicNetwork(),
'N': AtomicNetwork(aev_computer), 'N': AtomicNetwork(),
'O': AtomicNetwork(aev_computer), 'O': AtomicNetwork(),
}) })
class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
if os.path.isfile(filename): if os.path.isfile(filename):
model.load_state_dict(torch.load(filename)) model.load_state_dict(torch.load(filename))
else: else:
......
...@@ -24,23 +24,11 @@ training, validation, testing = torchani.data.load_or_create( ...@@ -24,23 +24,11 @@ training, validation, testing = torchani.data.load_or_create(
transform=[shift_energy.dataset_subtract_sae]) transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, batch_chunks) training = torchani.data.dataloader(training, batch_chunks)
validation = torchani.data.dataloader(validation, batch_chunks) validation = torchani.data.dataloader(validation, batch_chunks)
nnp = model.get_or_create_model(model_checkpoint) nnp = model.get_or_create_model(model_checkpoint)
batch_nnp = torchani.models.BatchModel(nnp)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
batch_nnp = torchani.models.BatchModel(Flatten(nnp))
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters()) optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer( trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss) container, optimizer, torchani.ignite.energy_mse_loss)
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
......
...@@ -15,19 +15,7 @@ dataset = torchani.data.ANIDataset( ...@@ -15,19 +15,7 @@ dataset = torchani.data.ANIDataset(
transform=[shift_energy.dataset_subtract_sae]) transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, batch_chunks) dataloader = torchani.data.dataloader(dataset, batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True) nnp = model.get_or_create_model('/tmp/model.pt', True)
batch_nnp = torchani.models.BatchModel(nnp)
class Flatten(torch.nn.Module):
def __init__(self, model):
super(Flatten, self).__init__()
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
batch_nnp = torchani.models.BatchModel(Flatten(nnp))
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters()) optimizer = torch.optim.Adam(nnp.parameters())
...@@ -53,14 +41,13 @@ def finalize_tqdm(trainer): ...@@ -53,14 +41,13 @@ def finalize_tqdm(trainer):
start = timeit.default_timer() start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1) trainer.run(dataloader, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', nnp.aev_computer.timers['radial terms']) print('Radial terms:', nnp[1].timers['radial terms'])
print('Angular terms:', nnp.aev_computer.timers['angular terms']) print('Angular terms:', nnp[1].timers['angular terms'])
print('Terms and indices:', nnp.aev_computer.timers['terms and indices']) print('Terms and indices:', nnp[1].timers['terms and indices'])
print('Combinations:', nnp.aev_computer.timers['combinations']) print('Combinations:', nnp[1].timers['combinations'])
print('Mask R:', nnp.aev_computer.timers['mask_r']) print('Mask R:', nnp[1].timers['mask_r'])
print('Mask A:', nnp.aev_computer.timers['mask_a']) print('Mask A:', nnp[1].timers['mask_a'])
print('Assemble:', nnp.aev_computer.timers['assemble']) print('Assemble:', nnp[1].timers['assemble'])
print('Total AEV:', nnp.aev_computer.timers['total']) print('Total AEV:', nnp[1].timers['total'])
print('NN:', nnp.timers['nn']) print('NN:', nnp[2].timers['forward'])
print('Total Forward:', nnp.timers['forward'])
print('Epoch time:', elapsed) print('Epoch time:', elapsed)
...@@ -11,15 +11,20 @@ N = 97 ...@@ -11,15 +11,20 @@ N = 97
class TestAEV(unittest.TestCase): class TestAEV(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype): def setUp(self, dtype=torchani.default_dtype):
self.aev = torchani.SortedAEV(dtype=dtype, device=torch.device('cpu')) aev_computer = torchani.SortedAEV(dtype=dtype,
device=torch.device('cpu'))
self.radial_length = aev_computer.radial_length
self.aev = torch.nn.Sequential(
torchani.PrepareInput(aev_computer.species, aev_computer.device),
aev_computer
)
self.tolerance = 1e-5 self.tolerance = 1e-5
def _test_molecule(self, coordinates, species, expected_radial, def _test_molecule(self, coordinates, species, expected_radial,
expected_angular): expected_angular):
species = self.aev.species_to_tensor(species) species, aev = self.aev((species, coordinates))
aev = self.aev((coordinates, species)) radial = aev[..., :self.radial_length]
radial = aev[..., :self.aev.radial_length] angular = aev[..., self.radial_length:]
angular = aev[..., self.aev.radial_length:]
radial_diff = expected_radial - radial radial_diff = expected_radial - radial
radial_max_error = torch.max(torch.abs(radial_diff)).item() radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular angular_diff = expected_angular - angular
......
...@@ -21,8 +21,11 @@ if sys.version_info.major >= 3: ...@@ -21,8 +21,11 @@ if sys.version_info.major >= 3:
ds = torchani.data.ANIDataset(path, chunksize, device=device) ds = torchani.data.ANIDataset(path, chunksize, device=device)
loader = torchani.data.dataloader(ds, batch_chunks) loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device) aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer) prepare = torchani.PrepareInput(aev_computer.species,
batch_nnp = torchani.models.BatchModel(nnp) aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
model = torch.nn.Sequential(prepare, aev_computer, nnp)
batch_nnp = torchani.models.BatchModel(model)
for batch_input, batch_output in itertools.islice(loader, 10): for batch_input, batch_output in itertools.islice(loader, 10):
batch_output_ = batch_nnp(batch_input).squeeze() batch_output_ = batch_nnp(batch_input).squeeze()
self.assertListEqual(list(batch_output_.shape), self.assertListEqual(list(batch_output_.shape),
......
...@@ -16,7 +16,7 @@ class TestBenchmark(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestBenchmark(unittest.TestCase):
self.conformations, 8, 3, dtype=dtype, device=device) self.conformations, 8, 3, dtype=dtype, device=device)
self.count = 100 self.count = 100
def _testModule(self, module, asserts): def _testModule(self, run_module, result_module, asserts):
keys = [] keys = []
for i in asserts: for i in asserts:
if '>=' in i: if '>=' in i:
...@@ -36,57 +36,57 @@ class TestBenchmark(unittest.TestCase): ...@@ -36,57 +36,57 @@ class TestBenchmark(unittest.TestCase):
keys += [i[0].strip(), i[1].strip()] keys += [i[0].strip(), i[1].strip()]
else: else:
keys.append(i.strip()) keys.append(i.strip())
self.assertEqual(set(module.timers.keys()), set(keys)) self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys: for i in keys:
self.assertEqual(module.timers[i], 0) self.assertEqual(result_module.timers[i], 0)
old_timers = copy.copy(module.timers) old_timers = copy.copy(result_module.timers)
for _ in range(self.count): for _ in range(self.count):
if isinstance(module, torchani.aev.AEVComputer): run_module((self.species, self.coordinates))
species = module.species_to_tensor(self.species)
module((self.coordinates, species))
else:
module(self.coordinates, self.species)
for i in keys: for i in keys:
self.assertLess(old_timers[i], module.timers[i]) self.assertLess(old_timers[i], result_module.timers[i])
for i in asserts: for i in asserts:
if '>=' in i: if '>=' in i:
i = i.split('>=') i = i.split('>=')
key0 = i[0].strip() key0 = i[0].strip()
key1 = i[1].strip() key1 = i[1].strip()
self.assertGreaterEqual( self.assertGreaterEqual(
module.timers[key0], module.timers[key1]) result_module.timers[key0], result_module.timers[key1])
elif '<=' in i: elif '<=' in i:
i = i.split('<=') i = i.split('<=')
key0 = i[0].strip() key0 = i[0].strip()
key1 = i[1].strip() key1 = i[1].strip()
self.assertLessEqual( self.assertLessEqual(
module.timers[key0], module.timers[key1]) result_module.timers[key0], result_module.timers[key1])
elif '>' in i: elif '>' in i:
i = i.split('>') i = i.split('>')
key0 = i[0].strip() key0 = i[0].strip()
key1 = i[1].strip() key1 = i[1].strip()
self.assertGreater( self.assertGreater(
module.timers[key0], module.timers[key1]) result_module.timers[key0], result_module.timers[key1])
elif '<' in i: elif '<' in i:
i = i.split('<') i = i.split('<')
key0 = i[0].strip() key0 = i[0].strip()
key1 = i[1].strip() key1 = i[1].strip()
self.assertLess(module.timers[key0], module.timers[key1]) self.assertLess(result_module.timers[key0],
result_module.timers[key1])
elif '=' in i: elif '=' in i:
i = i.split('=') i = i.split('=')
key0 = i[0].strip() key0 = i[0].strip()
key1 = i[1].strip() key1 = i[1].strip()
self.assertEqual(module.timers[key0], module.timers[key1]) self.assertEqual(result_module.timers[key0],
old_timers = copy.copy(module.timers) result_module.timers[key1])
module.reset_timers() old_timers = copy.copy(result_module.timers)
self.assertEqual(set(module.timers.keys()), set(keys)) result_module.reset_timers()
self.assertEqual(set(result_module.timers.keys()), set(keys))
for i in keys: for i in keys:
self.assertEqual(module.timers[i], 0) self.assertEqual(result_module.timers[i], 0)
def testAEV(self): def testAEV(self):
aev_computer = torchani.SortedAEV( aev_computer = torchani.SortedAEV(
benchmark=True, dtype=self.dtype, device=self.device) benchmark=True, dtype=self.dtype, device=self.device)
self._testModule(aev_computer, [ prepare = torchani.PrepareInput(aev_computer.species, self.device)
run_module = torch.nn.Sequential(prepare, aev_computer)
self._testModule(run_module, aev_computer, [
'terms and indices>radial terms', 'terms and indices>radial terms',
'terms and indices>angular terms', 'terms and indices>angular terms',
'total>terms and indices', 'total>terms and indices',
...@@ -97,9 +97,11 @@ class TestBenchmark(unittest.TestCase): ...@@ -97,9 +97,11 @@ class TestBenchmark(unittest.TestCase):
def testANIModel(self): def testANIModel(self):
aev_computer = torchani.SortedAEV( aev_computer = torchani.SortedAEV(
dtype=self.dtype, device=self.device) dtype=self.dtype, device=self.device)
prepare = torchani.PrepareInput(aev_computer.species, self.device)
model = torchani.models.NeuroChemNNP( model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True) aev_computer.species, benchmark=True)
self._testModule(model, ['forward>nn']) run_module = torch.nn.Sequential(prepare, aev_computer, model)
self._testModule(run_module, model, ['forward'])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -14,13 +14,16 @@ class TestEnergies(unittest.TestCase): ...@@ -14,13 +14,16 @@ class TestEnergies(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype, def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device): device=torchani.default_device):
self.tolerance = 5e-5 self.tolerance = 5e-5
self.aev_computer = torchani.SortedAEV( aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu')) dtype=dtype, device=torch.device('cpu'))
self.nnp = torchani.models.NeuroChemNNP(self.aev_computer) prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, energies): def _test_molecule(self, coordinates, species, energies):
shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file) shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file)
energies_ = self.nnp(coordinates, species).squeeze() energies_ = self.model((species, coordinates)).squeeze()
energies_ = shift_energy.add_sae(energies_, species) energies_ = shift_energy.add_sae(energies_, species)
max_diff = (energies - energies_).abs().max().item() max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance) self.assertLess(max_diff, self.tolerance)
......
...@@ -19,15 +19,18 @@ class TestEnsemble(unittest.TestCase): ...@@ -19,15 +19,18 @@ class TestEnsemble(unittest.TestCase):
n = torchani.buildin_ensemble n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV(device=torch.device('cpu')) aev = torchani.SortedAEV(device=torch.device('cpu'))
ensemble = torchani.models.NeuroChemNNP(aev, ensemble=True) prepare = torchani.PrepareInput(aev.species, aev.device)
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
ensemble = torch.nn.Sequential(prepare, aev, ensemble)
models = [torchani.models. models = [torchani.models.
NeuroChemNNP(aev, ensemble=False, NeuroChemNNP(aev.species, ensemble=False,
from_=prefix + '{}/networks/'.format(i)) from_=prefix + '{}/networks/'.format(i))
for i in range(n)] for i in range(n)]
models = [torch.nn.Sequential(prepare, aev, m) for m in models]
energy1 = ensemble(coordinates, species) energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0] force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m(coordinates, species) for m in models] energy2 = [m((species, coordinates)) for m in models]
energy2 = sum(energy2) / n energy2 = sum(energy2) / n
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0] force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item() energy_diff = (energy1 - energy2).abs().max().item()
......
...@@ -13,14 +13,16 @@ class TestForce(unittest.TestCase): ...@@ -13,14 +13,16 @@ class TestForce(unittest.TestCase):
def setUp(self, dtype=torchani.default_dtype, def setUp(self, dtype=torchani.default_dtype,
device=torchani.default_device): device=torchani.default_device):
self.tolerance = 1e-5 self.tolerance = 1e-5
self.aev_computer = torchani.SortedAEV( aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu')) dtype=dtype, device=torch.device('cpu'))
self.nnp = torchani.models.NeuroChemNNP( prepare = torchani.PrepareInput(aev_computer.species,
self.aev_computer) aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, forces): def _test_molecule(self, coordinates, species, forces):
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates = torch.tensor(coordinates, requires_grad=True)
energies = self.nnp(coordinates, species) energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0] derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item() max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance) self.assertLess(max_diff, self.tolerance)
......
...@@ -26,19 +26,16 @@ if sys.version_info.major >= 3: ...@@ -26,19 +26,16 @@ if sys.version_info.major >= 3:
ds = torch.utils.data.Subset(ds, [0]) ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1) loader = torchani.data.dataloader(ds, 1)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device) aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer) prepare = torchani.PrepareInput(aev_computer.species,
aev_computer.device)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
def __init__(self, model): model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
super(Flatten, self).__init__() batch_nnp = torchani.models.BatchModel(model)
self.model = model
def forward(self, *input):
return self.model(*input).flatten()
nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(container.parameters()) optimizer = torch.optim.Adam(container.parameters())
trainer = create_supervised_trainer( trainer = create_supervised_trainer(
......
...@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter ...@@ -2,11 +2,12 @@ from .energyshifter import EnergyShifter
from . import models from . import models
from . import data from . import data
from . import ignite from . import ignite
from .aev import SortedAEV from .aev import SortedAEV, PrepareInput
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble, default_dtype, default_device buildin_model_prefix, buildin_ensemble, default_dtype, default_device
__all__ = ['SortedAEV', 'EnergyShifter', 'models', 'data', 'ignite', __all__ = ['PrepareInput', 'SortedAEV', 'EnergyShifter',
'models', 'data', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir', 'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble', 'buildin_model_prefix', 'buildin_ensemble',
'default_dtype', 'default_device'] 'default_dtype', 'default_device']
...@@ -89,6 +89,52 @@ class AEVComputer(BenchmarkedModule): ...@@ -89,6 +89,52 @@ class AEVComputer(BenchmarkedModule):
self.ShfA = self.ShfA.view(1, 1, -1, 1) self.ShfA = self.ShfA.view(1, 1, -1, 1)
self.ShfZ = self.ShfZ.view(1, 1, 1, -1) self.ShfZ = self.ShfZ.view(1, 1, 1, -1)
def forward(self, coordinates_species):
"""Compute AEV from coordinates and species
Parameters
----------
(species, coordinates)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
Returns
-------
(torch.Tensor, torch.LongTensor)
Returns full AEV and species
"""
raise NotImplementedError('subclass must override this method')
class PrepareInput(torch.nn.Module):
def __init__(self, species, device):
super(PrepareInput, self).__init__()
self.species = species
self.device = device
def species_to_tensor(self, species):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
Returns
-------
torch.Tensor
Long tensor for the species, where a value k means the species is
the same as self.species[k].
"""
indices = {self.species[i]: i for i in range(len(self.species))}
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=self.device)
def sort_by_species(self, species, *tensors): def sort_by_species(self, species, *tensors):
"""Sort the data by its species according to the order in `self.species` """Sort the data by its species according to the order in `self.species`
...@@ -110,28 +156,10 @@ class AEVComputer(BenchmarkedModule): ...@@ -110,28 +156,10 @@ class AEVComputer(BenchmarkedModule):
new_tensors.append(t.index_select(1, reverse)) new_tensors.append(t.index_select(1, reverse))
return (species, *tensors) return (species, *tensors)
def forward(self, coordinates_species): def forward(self, species_coordinates):
"""Compute AEV from coordinates and species species, coordinates = species_coordinates
species = self.species_to_tensor(species)
Parameters return self.sort_by_species(species, coordinates)
----------
(coordinates, species)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
species : torch.LongTensor
Long tensor for the species, where a value k means the species is
the same as self.species[k]
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor
of `dtype`. The radial AEV must be of shape
(conformations, atoms, radial_length). The angular AEV must
be of shape (conformations, atoms, angular_length)
"""
raise NotImplementedError('subclass must override this method')
def _cutoff_cosine(distances, cutoff): def _cutoff_cosine(distances, cutoff):
...@@ -194,24 +222,6 @@ class SortedAEV(AEVComputer): ...@@ -194,24 +222,6 @@ class SortedAEV(AEVComputer):
self.assemble = self._enable_benchmark(self.assemble, 'assemble') self.assemble = self._enable_benchmark(self.assemble, 'assemble')
self.forward = self._enable_benchmark(self.forward, 'total') self.forward = self._enable_benchmark(self.forward, 'total')
def species_to_tensor(self, species):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
Returns
-------
torch.Tensor
Long tensor for the species, where a value k means the species is
the same as self.species[k].
"""
indices = {self.species[i]: i for i in range(len(self.species))}
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=self.device)
def radial_subaev_terms(self, distances): def radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors """Compute the radial subAEV terms of the center atom given neighbors
...@@ -482,8 +492,8 @@ class SortedAEV(AEVComputer): ...@@ -482,8 +492,8 @@ class SortedAEV(AEVComputer):
return radial_aevs, torch.cat(angular_aevs, dim=2) return radial_aevs, torch.cat(angular_aevs, dim=2)
def forward(self, coordinates_species): def forward(self, species_coordinates):
coordinates, species = coordinates_species species, coordinates = species_coordinates
present_species = species.unique(sorted=True) present_species = species.unique(sorted=True)
radial_terms, angular_terms, indices_r, indices_a = \ radial_terms, angular_terms, indices_r, indices_a = \
...@@ -497,4 +507,4 @@ class SortedAEV(AEVComputer): ...@@ -497,4 +507,4 @@ class SortedAEV(AEVComputer):
radial, angular = self.assemble(radial_terms, angular_terms, radial, angular = self.assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a) present_species, mask_r, mask_a)
fullaev = torch.cat([radial, angular], dim=2) fullaev = torch.cat([radial, angular], dim=2)
return fullaev return species, fullaev
from ..aev import AEVComputer
import torch import torch
from ..benchmarked import BenchmarkedModule from ..benchmarked import BenchmarkedModule
...@@ -9,10 +8,10 @@ class ANIModel(BenchmarkedModule): ...@@ -9,10 +8,10 @@ class ANIModel(BenchmarkedModule):
Attributes Attributes
---------- ----------
aev_computer : AEVComputer species : list
The AEV computer. Chemical symbol of supported atom species.
output_length : int output_length : int
The length of output vector The length of output vector.
suffixes : sequence suffixes : sequence
Different suffixes denote different models in an ensemble. Different suffixes denote different models in an ensemble.
model_<X><suffix> : nn.Module model_<X><suffix> : nn.Module
...@@ -26,30 +25,15 @@ class ANIModel(BenchmarkedModule): ...@@ -26,30 +25,15 @@ class ANIModel(BenchmarkedModule):
the tensor containing desired output. the tensor containing desired output.
output_length : int output_length : int
Length of output of each submodel. Length of output of each submodel.
derivative : boolean
Whether to support computing the derivative w.r.t coordinates,
i.e. d(output)/dR
derivative_graph : boolean
Whether to generate a graph for the derivative. This would be required
only if the derivative is included as part of the loss function.
timers : dict timers : dict
Dictionary storing the the benchmark result. It has the following keys: Dictionary storing the the benchmark result. It has the following keys:
aev : time spent on computing AEV.
nn : time spent on computing output from AEV.
derivative : time spend on computing derivative w.r.t. coordinates
after the outputs is given. This key is only available if
derivative computation is turned on.
forward : total time for the forward pass forward : total time for the forward pass
""" """
def __init__(self, aev_computer, suffixes, reducer, output_length, models, def __init__(self, species, suffixes, reducer, output_length, models,
benchmark=False): benchmark=False):
super(ANIModel, self).__init__(benchmark) super(ANIModel, self).__init__(benchmark)
if not isinstance(aev_computer, AEVComputer): self.species = species
raise TypeError(
"ModelOnAEV: aev_computer must be a subclass of AEVComputer")
self.aev_computer = aev_computer
self.suffixes = suffixes self.suffixes = suffixes
self.reducer = reducer self.reducer = reducer
self.output_length = output_length self.output_length = output_length
...@@ -57,20 +41,19 @@ class ANIModel(BenchmarkedModule): ...@@ -57,20 +41,19 @@ class ANIModel(BenchmarkedModule):
setattr(self, i, models[i]) setattr(self, i, models[i])
if benchmark: if benchmark:
self.aev_to_output = self._enable_benchmark(
self.aev_to_output, 'nn')
self.forward = self._enable_benchmark(self.forward, 'forward') self.forward = self._enable_benchmark(self.forward, 'forward')
def aev_to_output(self, aev, species): def forward(self, species_aev):
"""Compute output from aev """Compute output from aev
Parameters Parameters
---------- ----------
(species, aev)
species : torch.Tensor
Tensor storing the species for each atom.
aev : torch.Tensor aev : torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs. the computed AEVs.
species : torch.Tensor
Tensor storing the species for each atom.
Returns Returns
------- -------
...@@ -78,6 +61,7 @@ class ANIModel(BenchmarkedModule): ...@@ -78,6 +61,7 @@ class ANIModel(BenchmarkedModule):
Pytorch tensor of shape (conformations, output_length) for the Pytorch tensor of shape (conformations, output_length) for the
output of each conformation. output of each conformation.
""" """
species, aev = species_aev
conformations = aev.shape[0] conformations = aev.shape[0]
atoms = len(species) atoms = len(species)
rev_species = species.__reversed__() rev_species = species.__reversed__()
...@@ -88,11 +72,11 @@ class ANIModel(BenchmarkedModule): ...@@ -88,11 +72,11 @@ class ANIModel(BenchmarkedModule):
for s in species_dedup: for s in species_dedup:
begin = species.index(s) begin = species.index(s)
end = atoms - rev_species.index(s) end = atoms - rev_species.index(s)
y = aev[:, begin:end, :].reshape(-1, self.aev_computer.aev_length) y = aev[:, begin:end, :].flatten(0, 1)
def apply_model(suffix): def apply_model(suffix):
model_X = getattr(self, 'model_' + model_X = getattr(self, 'model_' +
self.aev_computer.species[s] + suffix) self.species[s] + suffix)
return model_X(y) return model_X(y)
ys = [apply_model(suffix) for suffix in self.suffixes] ys = [apply_model(suffix) for suffix in self.suffixes]
y = sum(ys) / len(ys) y = sum(ys) / len(ys)
...@@ -102,26 +86,3 @@ class ANIModel(BenchmarkedModule): ...@@ -102,26 +86,3 @@ class ANIModel(BenchmarkedModule):
per_species_outputs = torch.cat(per_species_outputs, dim=1) per_species_outputs = torch.cat(per_species_outputs, dim=1)
molecule_output = self.reducer(per_species_outputs, dim=1) molecule_output = self.reducer(per_species_outputs, dim=1)
return molecule_output return molecule_output
def forward(self, coordinates, species):
"""Feed forward
Parameters
----------
coordinates : torch.Tensor
The pytorch tensor of shape (conformations, atoms, 3) storing
the coordinates of all atoms of all conformations.
species : list of string
List of string storing the species for each atom.
Returns
-------
torch.Tensor
Tensor of shape (conformations, output_length) for the
output of each conformation.
"""
species = self.aev_computer.species_to_tensor(species)
_species, _coordinates, = self.aev_computer.sort_by_species(
species, coordinates)
aev = self.aev_computer((_coordinates, _species))
return self.aev_to_output(aev, _species)
...@@ -12,5 +12,5 @@ class BatchModel(torch.nn.Module): ...@@ -12,5 +12,5 @@ class BatchModel(torch.nn.Module):
for i in batch: for i in batch:
coordinates = i['coordinates'] coordinates = i['coordinates']
species = i['species'] species = i['species']
results.append(self.model(coordinates, species)) results.append(self.model((species, coordinates)))
return torch.cat(results) return torch.cat(results)
...@@ -3,7 +3,7 @@ from .ani_model import ANIModel ...@@ -3,7 +3,7 @@ from .ani_model import ANIModel
class CustomModel(ANIModel): class CustomModel(ANIModel):
def __init__(self, aev_computer, per_species, reducer, def __init__(self, per_species, reducer,
derivative=False, derivative_graph=False, benchmark=False): derivative=False, derivative_graph=False, benchmark=False):
"""Custom single model, no ensemble """Custom single model, no ensemble
...@@ -31,7 +31,8 @@ class CustomModel(ANIModel): ...@@ -31,7 +31,8 @@ class CustomModel(ANIModel):
raise ValueError( raise ValueError(
'''output length of each atomic neural network must '''output length of each atomic neural network must
match''') match''')
super(CustomModel, self).__init__(aev_computer, suffixes, reducer, super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
output_length, models, benchmark) reducer, output_length, models,
benchmark)
for i in per_species: for i in per_species:
setattr(self, 'model_' + i, per_species[i]) setattr(self, 'model_' + i, per_species[i])
...@@ -12,10 +12,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -12,10 +12,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
Attributes Attributes
---------- ----------
dtype : torch.dtype
Pytorch data type for tensors
device : torch.Device
The device where tensors should be.
layers : int layers : int
Number of layers. Number of layers.
output_length : int output_length : int
...@@ -29,13 +25,11 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -29,13 +25,11 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
The NeuroChem index for activation. The NeuroChem index for activation.
""" """
def __init__(self, dtype, device, filename): def __init__(self, filename):
"""Initialize from NeuroChem network directory. """Initialize from NeuroChem network directory.
Parameters Parameters
---------- ----------
dtype : torch.dtype
Pytorch data type for tensors
filename : string filename : string
The file name for the `.nnf` file that store network The file name for the `.nnf` file that store network
hyperparameters. The `.bparam` and `.wparam` must be hyperparameters. The `.bparam` and `.wparam` must be
...@@ -43,8 +37,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -43,8 +37,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
""" """
super(NeuroChemAtomicNetwork, self).__init__() super(NeuroChemAtomicNetwork, self).__init__()
self.dtype = dtype
self.device = device
networ_dir = os.path.dirname(filename) networ_dir = os.path.dirname(filename)
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
buffer = f.read() buffer = f.read()
...@@ -204,7 +196,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -204,7 +196,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
raise NotImplementedError( raise NotImplementedError(
'''different activation on different '''different activation on different
layers are not supported''') layers are not supported''')
linear = torch.nn.Linear(in_size, out_size).type(self.dtype) linear = torch.nn.Linear(in_size, out_size)
name = 'layer{}'.format(i) name = 'layer{}'.format(i)
setattr(self, name, linear) setattr(self, name, linear)
if in_size * out_size != wsz or out_size != bsz: if in_size * out_size != wsz or out_size != bsz:
...@@ -219,14 +211,12 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -219,14 +211,12 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
wsize = in_size * out_size wsize = in_size * out_size
fw = open(wfn, 'rb') fw = open(wfn, 'rb')
w = struct.unpack('{}f'.format(wsize), fw.read()) w = struct.unpack('{}f'.format(wsize), fw.read())
w = torch.tensor(w, dtype=self.dtype, device=self.device).view( w = torch.tensor(w).view(out_size, in_size)
out_size, in_size)
linear.weight.data = w linear.weight.data = w
fw.close() fw.close()
fb = open(bfn, 'rb') fb = open(bfn, 'rb')
b = struct.unpack('{}f'.format(out_size), fb.read()) b = struct.unpack('{}f'.format(out_size), fb.read())
b = torch.tensor(b, dtype=self.dtype, b = torch.tensor(b).view(out_size)
device=self.device).view(out_size)
linear.bias.data = b linear.bias.data = b
fb.close() fb.close()
......
...@@ -7,8 +7,7 @@ from ..env import buildin_network_dir, buildin_model_prefix, buildin_ensemble ...@@ -7,8 +7,7 @@ from ..env import buildin_network_dir, buildin_model_prefix, buildin_ensemble
class NeuroChemNNP(ANIModel): class NeuroChemNNP(ANIModel):
def __init__(self, aev_computer, from_=None, ensemble=False, def __init__(self, species, from_=None, ensemble=False, benchmark=False):
derivative=False, derivative_graph=False, benchmark=False):
"""If from_=None then ensemble must be a boolean. If ensemble=False, """If from_=None then ensemble must be a boolean. If ensemble=False,
then use buildin network0, else use buildin network ensemble. then use buildin network0, else use buildin network ensemble.
If from_ != None, ensemble must be either False or an integer If from_ != None, ensemble must be either False or an integer
...@@ -46,12 +45,10 @@ class NeuroChemNNP(ANIModel): ...@@ -46,12 +45,10 @@ class NeuroChemNNP(ANIModel):
models = {} models = {}
output_length = None output_length = None
for network_dir, suffix in zip(network_dirs, suffixes): for network_dir, suffix in zip(network_dirs, suffixes):
for i in aev_computer.species: for i in species:
filename = os.path.join( filename = os.path.join(
network_dir, 'ANN-{}.nnf'.format(i)) network_dir, 'ANN-{}.nnf'.format(i))
model_X = NeuroChemAtomicNetwork( model_X = NeuroChemAtomicNetwork(filename)
aev_computer.dtype, aev_computer.device,
filename)
if output_length is None: if output_length is None:
output_length = model_X.output_length output_length = model_X.output_length
elif output_length != model_X.output_length: elif output_length != model_X.output_length:
...@@ -59,5 +56,5 @@ class NeuroChemNNP(ANIModel): ...@@ -59,5 +56,5 @@ class NeuroChemNNP(ANIModel):
'''output length of each atomic neural networt '''output length of each atomic neural networt
must match''') must match''')
models['model_' + i + suffix] = model_X models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(aev_computer, suffixes, reducer, super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
output_length, models, benchmark) output_length, models, benchmark)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment