Unverified Commit 6b058c6e authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove everything about chunking (#432)

* Remove everything about chunking

* aev.py

* neurochem trainer

* training-benchmark-nsys-profile.py

* fix eval

* training-benchmark.py

* nnp_training.py

* flake8

* nnp_training_force.py

* fix dtype of species

* fix

* flake8

* requires_grad_

* git ignore

* fix

* original

* fix

* fix

* fix

* fix

* save

* save

* save

* save

* save

* save

* save

* save

* save

* collate

* fix

* save

* fix

* save

* save

* fix

* save

* fix

* fix

* no len

* float

* save

* save

* save

* save

* save

* save

* save

* save

* save

* fix

* save

* save

* save

* save

* fix

* fix

* fix

* fix mypy

* don't remove outliers

* save

* save

* save

* fix

* flake8

* save

* fix

* flake8

* docs

* more docs

* fix test_data

* remove test_data_new

* fix
parent 338f896a
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from collections import OrderedDict from collections import OrderedDict
from torch import Tensor from torch import Tensor
from typing import Tuple, NamedTuple, Optional from typing import Tuple, NamedTuple, Optional
from . import utils
class SpeciesEnergies(NamedTuple): class SpeciesEnergies(NamedTuple):
...@@ -109,19 +110,9 @@ class Gaussian(torch.nn.Module): ...@@ -109,19 +110,9 @@ class Gaussian(torch.nn.Module):
class SpeciesConverter(torch.nn.Module): class SpeciesConverter(torch.nn.Module):
"""Convert from element index in the periodic table to 0, 1, 2, 3, ...""" """Convert from element index in the periodic table to 0, 1, 2, 3, ..."""
periodic_table = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
def __init__(self, species): def __init__(self, species):
super().__init__() super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)} rev_idx = {s: k for k, s in enumerate(utils.PERIODIC_TABLE, 1)}
maxidx = max(rev_idx.values()) maxidx = max(rev_idx.values())
self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long)) self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
for i, s in enumerate(species): for i, s in enumerate(species):
......
...@@ -2,45 +2,73 @@ import torch ...@@ -2,45 +2,73 @@ import torch
from torch import Tensor from torch import Tensor
import torch.utils.data import torch.utils.data
import math import math
import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Tuple, NamedTuple, Optional from typing import Tuple, NamedTuple, Optional
from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst
from .nn import SpeciesEnergies from .nn import SpeciesEnergies
def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)): def stack_with_padding(properties, padding):
output = defaultdict(lambda: [])
for p in properties:
for k, v in p.items():
output[k].append(v)
for k, v in output.items():
if v[0].dim() == 0:
output[k] = torch.stack(v)
else:
output[k] = torch.nn.utils.rnn.pad_sequence(v, True, padding[k])
return output
def broadcast_first_dim(properties):
    """Expand every tensor in ``properties`` to a common size along dim 0.

    The common size is taken from the first tensor whose leading dimension
    is not 1. Every other tensor must have a leading dimension of either 1
    or that common size; tensors with a leading dimension of 1 are expanded
    (via :meth:`torch.Tensor.expand`) to the common size.

    Arguments:
        properties (dict): maps property names to tensors. The dict is
            modified in place: each value is replaced by its expanded tensor.

    Returns:
        dict: the same ``properties`` object that was passed in.

    Raises:
        AssertionError: if two tensors have leading dimensions that are
            both different from 1 and different from each other.
    """
    num_molecule = 1
    for v in properties.values():
        n = v.shape[0]
        if num_molecule == 1:
            # No non-trivial size seen yet: adopt this one (it may still be 1).
            num_molecule = n
        elif n != 1 and n != num_molecule:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O; type and message are
            # kept identical to the assertion this replaces.
            raise AssertionError("unable to broadcast")
    for k, v in properties.items():
        shape = list(v.shape)
        shape[0] = num_molecule
        # Re-assigning an existing key does not change the dict's size, so
        # doing it while iterating over items() is safe.
        properties[k] = v.expand(shape)
    return properties
def pad_atomic_properties(properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
"""Put a sequence of atomic properties together into single tensor. """Put a sequence of atomic properties together into single tensor.
Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs
are `{'species': padded_tensor, ...}` are `{'species': padded_tensor, ...}`
Arguments: Arguments:
species_coordinates (:class:`collections.abc.Sequence`): sequence of properties (:class:`collections.abc.Sequence`): sequence of properties.
atomic properties.
padding_values (dict): the value to fill to pad tensors to same size padding_values (dict): the value to fill to pad tensors to same size
""" """
keys = list(atomic_properties[0]) vectors = [k for k in properties[0].keys() if properties[0][k].dim() > 1]
anykey = keys[0] scalars = [k for k in properties[0].keys() if properties[0][k].dim() == 1]
max_atoms = max(x[anykey].shape[1] for x in atomic_properties) padded_sizes = {k: max(x[k].shape[1] for x in properties) for k in vectors}
padded = {k: [] for k in keys} num_molecules = [x[vectors[0]].shape[0] for x in properties]
for p in atomic_properties: total_num_molecules = sum(num_molecules)
num_molecules = 1 output = {}
for v in p.values(): for k in scalars:
assert num_molecules in {1, v.shape[0]}, 'Number of molecules in different atomic properties mismatch' output[k] = torch.stack([x[k] for x in properties])
if v.shape[0] != 1: for k in vectors:
num_molecules = v.shape[0] tensor = properties[0][k]
for k, v in p.items(): shape = list(tensor.shape)
shape = list(v.shape) device = tensor.device
padatoms = max_atoms - shape[1] dtype = tensor.dtype
shape[1] = padatoms shape[0] = total_num_molecules
padding = v.new_full(shape, padding_values[k]) shape[1] = padded_sizes[k]
v = torch.cat([v, padding], dim=1) output[k] = torch.full(shape, padding_values[k], device=device, dtype=dtype)
shape = list(v.shape) index0 = 0
shape[0] = num_molecules for n, x in zip(num_molecules, properties):
v = v.expand(*shape) original_size = x[k].shape[1]
padded[k].append(v) output[k][index0: index0 + n, 0: original_size, ...] = x[k]
return {k: torch.cat(v) for k, v in padded.items()} index0 += n
return output
# @torch.jit.script # @torch.jit.script
...@@ -132,24 +160,6 @@ class EnergyShifter(torch.nn.Module): ...@@ -132,24 +160,6 @@ class EnergyShifter(torch.nn.Module):
self.register_buffer('self_energies', self_energies) self.register_buffer('self_energies', self_energies)
def sae_from_dataset(self, atomic_properties, properties):
    """Compute atomic self energies from dataset.

    Least-squares solution to a linear equation is calculated to output
    ``self_energies`` when ``self_energies = None`` is passed to
    :class:`torchani.EnergyShifter`

    Arguments:
        atomic_properties (dict): must contain the key ``'species'``.
            NOTE(review): presumably a ``(molecules, atoms)`` integer
            tensor -- confirm against the caller.
        properties (dict): must contain the key ``'energies'``.

    Returns:
        The least-squares coefficients as returned by
        :func:`numpy.linalg.lstsq`, with the trailing axis squeezed out.
        When ``self.fit_intercept`` is true the last coefficient belongs
        to the appended column of ones.
    """
    species = atomic_properties['species']
    energies = properties['energies']
    present_species_ = present_species(species)
    # Compare every entry of ``species`` against each present species and
    # sum over dim 1 to get one count column per present species.
    X = (species.unsqueeze(-1) == present_species_).sum(dim=1).to(torch.double)
    # Concatenate a vector of ones to find fit intercept
    if self.fit_intercept:
        # Build the column on X's device so torch.cat does not fail when
        # the species tensor does not live on the default device.
        ones = torch.ones(X.shape[0], 1, dtype=torch.double, device=X.device)
        X = torch.cat((X, ones), dim=-1)
    y = energies.unsqueeze(dim=-1)
    # NumPy can only consume CPU tensors that do not track gradients, so
    # convert explicitly instead of relying on implicit array conversion.
    coeff_, _, _, _ = np.linalg.lstsq(
        X.detach().cpu().numpy(), y.detach().cpu().numpy(), rcond=None)
    return coeff_.squeeze(-1)
def sae(self, species): def sae(self, species):
"""Compute self energies for molecules. """Compute self energies for molecules.
...@@ -171,19 +181,6 @@ class EnergyShifter(torch.nn.Module): ...@@ -171,19 +181,6 @@ class EnergyShifter(torch.nn.Module):
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double) self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double)
return self_energies.sum(dim=1) + intercept return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties):
    """Transformer that subtracts self energies from a dataset.

    If ``self.self_energies`` is ``None`` it is first filled in from
    :meth:`sae_from_dataset` (stored as a double tensor). The value under
    ``properties['energies']`` is then replaced, in place, by the energies
    cast to double minus ``self.sae(species)``.

    Arguments:
        atomic_properties (dict): must contain the key ``'species'``.
        properties (dict): must contain the key ``'energies'``.

    Returns:
        tuple: ``(atomic_properties, properties)`` -- the same objects that
        were passed in, with ``properties['energies']`` updated.
    """
    if self.self_energies is None:
        fitted = self.sae_from_dataset(atomic_properties, properties)
        self.self_energies = torch.tensor(fitted, dtype=torch.double)
    raw_energies = properties['energies']
    # Move the shift onto the energies' device before subtracting.
    shift = self.sae(atomic_properties['species']).to(raw_energies.device)
    properties['energies'] = raw_energies.to(torch.double) - shift
    return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor], def forward(self, species_energies: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
...@@ -414,6 +411,17 @@ def get_atomic_masses(species): ...@@ -414,6 +411,17 @@ def get_atomic_masses(species):
return masses return masses
# Chemical symbols of the elements from H to Og, laid out one period per
# line. ``.strip().split()`` flattens the text into a single list of symbols,
# so the position of a symbol in the list (counting from 1) is its place in
# the table as written here.
PERIODIC_TABLE = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
__all__ = ['pad_atomic_properties', 'present_species', 'hessian', __all__ = ['pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding', 'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts', 'get_atomic_masses'] 'ChemicalSymbolsToInts', 'get_atomic_masses']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment