Unverified Commit 13f53da8 authored by Gao, Xiang, committed by GitHub
Browse files

Cherry-pick roitberg-group#10 (#545)



* Add convenience functions useful for active learning [WIP] (#10)

* Add convenience functions useful for active learning

* avoid training outputs

* modify gitignore

* Add convenience functions to directly get atomic energies

* fix bug

* fix mypy

* flake8

* fix bugs

* flake8

* mypy

* Add tests for functions

* add test to workflows and flake8

* empty to trigger tests

* trigger

* delete new test

* readd new test

* avoid training outputs

* trigger

* trigger tests again, they are all passing on my side

* fix isclose in tests

* save

* fix
Co-authored-by: Ignacio Pickering <ign.pickering@gmail.com>
parent 20b5746a
......@@ -21,7 +21,7 @@ jobs:
test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py,
test_grad.py, test_cuaev.py]
test_grad.py, test_cuaev.py, test_al.py]
steps:
- uses: actions/checkout@v1
......
......@@ -38,3 +38,4 @@ Untitled.ipynb
.coverage
htmlcov/
/include
training_outputs/
......@@ -57,3 +57,19 @@ force = -derivative
# And print to see the result:
print('Energy:', energy.item())
print('Force:', force.squeeze())
###############################################################################
# you can also get the atomic energies (WARNING: these have no physical
# meaning) by calling:
_, atomic_energies = model.atomic_energies((species, coordinates))
###############################################################################
# this gives you the average (shifted) energies over all models of the ensemble by default,
# with the same shape as the coordinates. Dummy atoms, if present, will have an
# energy of zero
print('Average Atomic energies, for species 6 1 1 1 1', atomic_energies)
###############################################################################
# you can also access model specific atomic energies
_, atomic_energies = model.atomic_energies((species, coordinates), average=False)
print('Atomic energies of first model, for species 6 1 1 1 1', atomic_energies[0, :, :])
import torch
import torchani
import math
import unittest
class TestALAtomic(unittest.TestCase):
    """Checks the atomic-energy convenience functions on symmetric methane."""

    def setUp(self):
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(dev)
        self.model = torchani.models.ANI1x(
            periodic_table_index=True).to(self.device).double()
        self.converter = torchani.nn.SpeciesConverter(['H', 'C', 'N', 'O'])
        self.aev_computer = self.model.aev_computer
        self.ani_model = self.model.neural_networks
        self.first_model = self.model[0]
        # fully symmetric methane: four equivalent hydrogens around one carbon
        methane = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0],
                   [0.5, 0.5, 0.5]]
        self.coordinates = torch.tensor(methane,
                                        dtype=torch.double,
                                        device=self.device).unsqueeze(0)
        self.species = torch.tensor([[1, 1, 1, 1, 6]],
                                    dtype=torch.long,
                                    device=self.device)

    def testAverageAtomicEnergies(self):
        pair = (self.species, self.coordinates)
        _, energies = self.model.atomic_energies(pair)
        # one energy per atom per configuration: shape (C, A)
        self.assertTrue(energies.shape == self.coordinates.shape[:-1])
        # all four hydrogens are symmetry-equivalent, so their averaged
        # atomic energies must coincide with the reference value
        expected = torch.tensor(-0.54853380570289400620,
                                dtype=torch.double).to(self.device)
        self.assertTrue(torch.isclose(energies[:, :-1], expected).all())

    def testAtomicEnergies(self):
        pair = (self.species, self.coordinates)
        _, energies = self.model.atomic_energies(pair, average=False)
        # without averaging, one (C, A) slab per ensemble member
        self.assertTrue(energies.shape[1:] == self.coordinates.shape[:-1])
        self.assertTrue(energies.shape[0] == len(self.model.neural_networks))
        # spot-check the first hydrogen of the first ensemble member
        expected = torch.tensor(-0.54562734428531045605,
                                device=self.device,
                                dtype=torch.double)
        self.assertTrue(torch.isclose(energies[0, 0, 0], expected))
        # hydrogens must also be equivalent within every single member
        for member_energies in energies:
            self.assertTrue(
                (member_energies[:, :-1] == member_energies[:, 0]).all())
class TestALQBC(TestALAtomic):
    """Checks per-member energies and QBC factors of the ensemble."""

    def testMemberEnergies(self):
        # fully symmetric methane
        _, energies = self.model.members_energies(
            (self.species, self.coordinates))
        torch.set_printoptions(precision=15)
        # shape is (M, C): one row of energies per ensemble member
        self.assertTrue(energies.shape[-1] == self.coordinates.shape[0])
        self.assertTrue(energies.shape[0] == len(self.model.neural_networks))
        # the first row must match running member 0 on its own
        first = self.first_model((self.species, self.coordinates)).energies
        self.assertTrue(energies[0] == first)
        reference = torch.tensor(-40.277153758433975,
                                 dtype=torch.double,
                                 device=self.device)
        self.assertTrue(torch.isclose(energies[0], reference))

    def testQBC(self):
        # fully symmetric methane
        _, _, qbc = self.model.energies_qbcs((self.species, self.coordinates))
        torch.set_printoptions(precision=15)
        member_energies = self.model.members_energies(
            (self.species, self.coordinates)).energies
        std = member_energies.std(dim=0, unbiased=True)
        # qbc factor = std over members / sqrt(number of atoms)
        num_atoms = self.coordinates.shape[1]
        self.assertTrue(torch.isclose(std / math.sqrt(num_atoms), qbc))
        # also test with multiple conformations, the second one padded with
        # a dummy atom (species -1)
        methane = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0],
                   [0.5, 0.5, 0.5]]
        coord1 = torch.tensor(methane,
                              dtype=torch.double,
                              device=self.device).unsqueeze(0)
        coord2 = torch.randn(1, 5, 3, dtype=torch.double, device=self.device)
        coordinates = torch.cat((coord1, coord2), dim=0)
        species = torch.tensor([[1, 1, 1, 1, 6], [-1, 1, 1, 1, 1]],
                               dtype=torch.long,
                               device=self.device)
        member_energies = self.model.members_energies(
            (species, coordinates)).energies
        std = member_energies.std(dim=0, unbiased=True)
        _, _, qbc = self.model.energies_qbcs((species, coordinates))
        # the first molecule has 5 real atoms, the second only 4
        std[0] = std[0] / math.sqrt(5)
        std[1] = std[1] / math.sqrt(4)
        self.assertTrue(torch.isclose(std, qbc).all())
if __name__ == '__main__':
    # Run the active-learning test cases when executed as a script.
    unittest.main()
......@@ -35,12 +35,18 @@ import zipfile
import torch
from distutils import dir_util
from torch import Tensor
from typing import Tuple, Optional
from . import neurochem
from typing import Tuple, Optional, NamedTuple
from .nn import SpeciesConverter, SpeciesEnergies
from .aev import AEVComputer
class SpeciesEnergiesQBC(NamedTuple):
    """Named tuple of species, ensemble-mean energies and QBC factors.

    Returned by ``BuiltinEnsemble.energies_qbcs``.
    """
    # (C, A) long tensor of atomic species for each configuration
    species: Tensor
    # (C,) tensor of energies averaged over the ensemble members
    energies: Tensor
    # (C,) tensor of query-by-committee factors (member std / sqrt(num atoms))
    qbcs: Tensor
class BuiltinModel(torch.nn.Module):
r"""Private template for the builtin ANI models """
......@@ -148,6 +154,44 @@ class BuiltinModel(torch.nn.Module):
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)
@torch.jit.export
def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor],
                    cell: Optional[Tensor] = None,
                    pbc: Optional[Tensor] = None) -> SpeciesEnergies:
    """Calculates predicted atomic energies of all atoms in a molecule

    ..warning::
        Since this function does not call ``__call__`` directly,
        hooks are not registered and profiling is not done correctly by
        pytorch on it. It is meant as a convenience function for analysis
        and active learning.

    .. note:: The coordinates, and cell are in Angstrom, and the energies
        will be in Hartree.

    Args:
        species_coordinates: minibatch of configurations
        cell: the cell used in PBC computation, set to None if PBC is not enabled
        pbc: the bool tensor indicating which direction PBC is enabled, set to None if PBC is not enabled

    Returns:
        species_atomic_energies: species and energies for the given configurations
            note that the shape of species is (C, A), where C is
            the number of configurations and A the number of atoms, and
            the shape of energies is (C, A) for a BuiltinModel.
    """
    if self.periodic_table_index:
        species_coordinates = self.species_converter(species_coordinates)
    species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
    atomic_energies = self.neural_networks._atomic_energies((species, aevs))
    # Gather the per-species self energy of every atom. Advanced indexing
    # already returns a fresh tensor, so no defensive clone is needed
    # before the in-place write below.
    self_energies = self.energy_shifter.self_energies.to(species.device)[species]
    # dummy atoms are padded with species == -1 and must not be shifted
    self_energies[species == -1] = 0.0
    # shift all atomic energies individually
    assert self_energies.shape == atomic_energies.shape
    atomic_energies += self_energies
    return SpeciesEnergies(species, atomic_energies)
@torch.jit.export
def _recast_long_buffers(self):
self.species_converter.conv_tensor = self.species_converter.conv_tensor.to(dtype=torch.long)
......@@ -225,6 +269,36 @@ class BuiltinEnsemble(BuiltinModel):
energy_shifter, species_to_tensor, consts, sae_dict,
periodic_table_index)
@torch.jit.export
def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor],
                    cell: Optional[Tensor] = None,
                    pbc: Optional[Tensor] = None, average: bool = True) -> SpeciesEnergies:
    """Calculates predicted atomic energies of all atoms in a molecule

    see `:method:torchani.BuiltinModel.atomic_energies`

    If average is True (the default) it returns the average over all models
    (shape (C, A)), otherwise it returns one atomic energy per model (shape
    (M, C, A))

    Args:
        species_coordinates: minibatch of configurations
        cell: the cell used in PBC computation, set to None if PBC is not enabled
        pbc: the bool tensor indicating which direction PBC is enabled, set to None if PBC is not enabled
        average: whether to average the atomic energies over ensemble members

    Returns:
        species_atomic_energies: species and (self-energy shifted) atomic
            energies for the given configurations
    """
    if self.periodic_table_index:
        species_coordinates = self.species_converter(species_coordinates)
    species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
    members_list = []
    for nnp in self.neural_networks:
        members_list.append(nnp._atomic_energies((species, aevs)).unsqueeze(0))
    member_atomic_energies = torch.cat(members_list, dim=0)
    # Gather the per-species self energy of every atom. Advanced indexing
    # already returns a fresh tensor, so no defensive clone is needed
    # before the in-place write below.
    self_energies = self.energy_shifter.self_energies.to(species.device)[species]
    # dummy atoms are padded with species == -1 and must not be shifted
    self_energies[species == -1] = 0.0
    # shift all atomic energies individually (broadcasts over the M members)
    assert self_energies.shape == member_atomic_energies.shape[1:]
    member_atomic_energies += self_energies
    if average:
        return SpeciesEnergies(species, member_atomic_energies.mean(dim=0))
    return SpeciesEnergies(species, member_atomic_energies)
@classmethod
def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False):
# this is used to load only 1 model (by default model 0)
......@@ -263,6 +337,95 @@ class BuiltinEnsemble(BuiltinModel):
self.periodic_table_index)
return ret
@torch.jit.export
def members_energies(self, species_coordinates: Tuple[Tensor, Tensor],
                     cell: Optional[Tensor] = None,
                     pbc: Optional[Tensor] = None) -> SpeciesEnergies:
    """Calculates predicted energies of all member modules

    ..warning::
        Since this function does not call ``__call__`` directly,
        hooks are not registered and profiling is not done correctly by
        pytorch on it. It is meant as a convenience function for analysis
        and active learning.

    .. note:: The coordinates, and cell are in Angstrom, and the energies
        will be in Hartree.

    Args:
        species_coordinates: minibatch of configurations
        cell: the cell used in PBC computation, set to None if PBC is not enabled
        pbc: the bool tensor indicating which direction PBC is enabled, set to None if PBC is not enabled

    Returns:
        species_energies: species and energies for the given configurations
            note that the shape of species is (C, A), where C is
            the number of configurations and A the number of atoms, and
            the shape of energies is (M, C), where M is the number
            of modules in the ensemble
    """
    if self.periodic_table_index:
        species_coordinates = self.species_converter(species_coordinates)
    species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
    # run every ensemble member on the shared AEVs and shift its energies
    energies_per_member = []
    for member in self.neural_networks:
        raw = member((species, aevs)).energies
        shifted = self.energy_shifter((species, raw)).energies
        energies_per_member.append(shifted.unsqueeze(0))
    stacked = torch.cat(energies_per_member, dim=0)
    return SpeciesEnergies(species, stacked)
@torch.jit.export
def energies_qbcs(self, species_coordinates: Tuple[Tensor, Tensor],
                  cell: Optional[Tensor] = None,
                  pbc: Optional[Tensor] = None, unbiased: bool = True) -> SpeciesEnergiesQBC:
    """Calculates predicted predicted energies and qbc factors

    QBC factors are used for query-by-committee (QBC) based active learning
    (as described in the ANI-1x paper `less-is-more`_ ).

    .. _less-is-more:
        https://aip.scitation.org/doi/10.1063/1.5023802

    ..warning::
        Since this function does not call ``__call__`` directly,
        hooks are not registered and profiling is not done correctly by
        pytorch on it. It is meant as a convenience function for analysis
        and active learning.

    .. note:: The coordinates, and cell are in Angstrom, and the energies
        and qbc factors will be in Hartree.

    Args:
        species_coordinates: minibatch of configurations
        cell: the cell used in PBC computation, set to None if PBC is not
            enabled
        pbc: the bool tensor indicating which direction PBC is enabled, set
            to None if PBC is not enabled
        unbiased: if `True` then Bessel's correction is applied to the
            standard deviation over the ensemble member's. If `False` Bessel's
            correction is not applied, True by default.

    Returns:
        species_energies_qbcs: species, energies and qbc factors for the
            given configurations note that the shape of species is (C, A),
            where C is the number of configurations and A the number of
            atoms, the shape of energies is (C,) and the shape of qbc
            factors is also (C,).
    """
    species, member_energies = self.members_energies(species_coordinates, cell, pbc)
    # spread of the ensemble members' predictions for each molecule
    rho = member_energies.std(0, unbiased=unbiased)
    # rho's (qbc factors) are weighted by dividing by the square root of
    # the number of real (non-dummy) atoms in each molecule
    atom_counts = (species >= 0).sum(dim=1, dtype=member_energies.dtype)
    rho = rho / atom_counts.sqrt()
    mean_energies = member_energies.mean(dim=0)
    assert rho.shape == mean_energies.shape
    return SpeciesEnergiesQBC(species, mean_energies, rho)
def __len__(self):
"""Get the number of networks in the ensemble
......
......@@ -55,21 +55,28 @@ class ANIModel(torch.nn.ModuleDict):
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
species, aev = species_aev
assert species.shape == aev.shape[:-1]
atomic_energies = self._atomic_energies((species, aev))
# shape of atomic energies is (C, A)
return SpeciesEnergies(species, torch.sum(atomic_energies, dim=1))
@torch.jit.export
def _atomic_energies(self, species_aev: Tuple[Tensor, Tensor]) -> Tensor:
# Obtain the atomic energies associated with a given tensor of AEV's
species, aev = species_aev
assert species.shape == aev.shape[:-1]
species_ = species.flatten()
aev = aev.flatten(0, 1)
output = aev.new_zeros(species_.shape)
for i, (_, m) in enumerate(self.items()):
for i, m in enumerate(self.values()):
mask = (species_ == i)
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species)
return SpeciesEnergies(species, torch.sum(output, dim=1))
return output
class Ensemble(torch.nn.ModuleList):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment