".github/vscode:/vscode.git/clone" did not exist on "ea7a96ddb3d3e75f267e5f07f77f4cbc42f6eb5b"
Unverified Commit 6b058c6e authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove everything about chunking (#432)

* Remove everything about chunking

* aev.py

* neurochem trainer

* training-benchmark-nsys-profile.py

* fix eval

* training-benchmark.py

* nnp_training.py

* flake8

* nnp_training_force.py

* fix dtype of species

* fix

* flake8

* requires_grad_

* git ignore

* fix

* original

* fix

* fix

* fix

* fix

* save

* save

* save

* save

* save

* save

* save

* save

* save

* collate

* fix

* save

* fix

* save

* save

* fix

* save

* fix

* fix

* no len

* float

* save

* save

* save

* save

* save

* save

* save

* save

* save

* fix

* save

* save

* save

* save

* fix

* fix

* fix

* fix mypy

* don't remove outliers

* save

* save

* save

* fix

* flake8

* save

* fix

* flake8

* docs

* more docs

* fix test_data

* remove test_data_new

* fix
parent 338f896a
......@@ -2,6 +2,7 @@ import torch
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional
from . import utils
class SpeciesEnergies(NamedTuple):
......@@ -109,19 +110,9 @@ class Gaussian(torch.nn.Module):
class SpeciesConverter(torch.nn.Module):
"""Convert from element index in the periodic table to 0, 1, 2, 3, ..."""
periodic_table = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
def __init__(self, species):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
rev_idx = {s: k for k, s in enumerate(utils.PERIODIC_TABLE, 1)}
maxidx = max(rev_idx.values())
self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
for i, s in enumerate(species):
......
......@@ -2,45 +2,73 @@ import torch
from torch import Tensor
import torch.utils.data
import math
import numpy as np
from collections import defaultdict
from typing import Tuple, NamedTuple, Optional
from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst
from .nn import SpeciesEnergies
def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
def stack_with_padding(properties, padding):
    """Collate a sequence of per-sample property dicts into batched tensors.

    Values are grouped by key across all samples.  A key whose first value is
    a 0-dimensional tensor is combined with :func:`torch.stack`; any other
    key is combined with :func:`torch.nn.utils.rnn.pad_sequence` using
    ``batch_first=True`` and ``padding[key]`` as the fill value.

    Arguments:
        properties (:class:`collections.abc.Iterable`): iterable of dicts
            mapping a property name to a tensor.
        padding (dict): fill value for each key that needs padding; it is
            only looked up for keys whose tensors are not 0-dimensional.

    Returns:
        collections.defaultdict: maps every key seen to its batched tensor.
    """
    collected = defaultdict(list)
    for sample in properties:
        for key, value in sample.items():
            collected[key].append(value)
    # Snapshot the keys so the values can be replaced while looping.
    for key in list(collected):
        tensors = collected[key]
        if tensors[0].dim() == 0:
            batched = torch.stack(tensors)
        else:
            batched = torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=padding[key])
        collected[key] = batched
    return collected
def broadcast_first_dim(properties):
    """Expand every tensor in ``properties`` to a common first dimension.

    The common size is the first leading-dimension size different from 1
    that is encountered (or 1 if there is none).  Every later tensor must
    have a leading size of either 1 or that common size.

    Arguments:
        properties (dict): maps a name to a tensor with at least one
            dimension.  The dict is modified in place.

    Returns:
        dict: the same ``properties`` object, with each value replaced by
        its expanded view.

    Raises:
        AssertionError: if two tensors have leading sizes that cannot be
            broadcast together.
    """
    target = 1
    for tensor in properties.values():
        leading = tensor.shape[0]
        if target == 1:
            # Nothing larger than 1 seen yet: adopt this size.
            target = leading
        else:
            assert leading == 1 or leading == target, "unable to broadcast"
    for key, tensor in properties.items():
        expanded_shape = list(tensor.shape)
        expanded_shape[0] = target
        properties[key] = tensor.expand(expanded_shape)
    return properties
def pad_atomic_properties(properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
    """Put a sequence of atomic properties together into single tensor.

    Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs
    are `{'species': padded_tensor, ...}`

    The set of keys and the dimensionality of each key are taken from the
    first element of ``properties``.  Keys whose tensor has more than one
    dimension are padded along dimension 1 to the largest size found and
    concatenated along dimension 0.  Keys whose tensor is 1-dimensional are
    combined with :func:`torch.stack`.  Keys whose tensor is 0-dimensional
    in the first element are not included in the output.

    Arguments:
        properties (:class:`collections.abc.Sequence`): sequence of properties.
        padding_values (dict): the value to fill to pad tensors to same size

    Returns:
        dict: maps each retained key to its combined tensor.
    """
    first = properties[0]
    vectors = [k for k in first if first[k].dim() > 1]
    scalars = [k for k in first if first[k].dim() == 1]

    output = {}
    for k in scalars:
        output[k] = torch.stack([x[k] for x in properties])
    if not vectors:
        # Only 1-D properties were given: there is nothing to pad, and no
        # multi-dimensional key to read the molecule counts from.
        return output

    # Dimension 0 of the first multi-dimensional key gives the number of rows
    # each element contributes to the concatenated output.
    num_molecules = [x[vectors[0]].shape[0] for x in properties]
    total_num_molecules = sum(num_molecules)
    for k in vectors:
        template = first[k]
        shape = list(template.shape)
        shape[0] = total_num_molecules
        shape[1] = max(x[k].shape[1] for x in properties)
        # Start from a tensor filled with the padding value, then copy every
        # element's data into its own row range.
        # NOTE: with the default ``padding_values`` this lookup inserts ``k``
        # into the shared default dict; harmless, since the value is 0.0.
        padded = torch.full(shape, padding_values[k],
                            device=template.device, dtype=template.dtype)
        index0 = 0
        for n, x in zip(num_molecules, properties):
            original_size = x[k].shape[1]
            padded[index0: index0 + n, 0: original_size, ...] = x[k]
            index0 += n
        output[k] = padded
    return output
# @torch.jit.script
......@@ -132,24 +160,6 @@ class EnergyShifter(torch.nn.Module):
self.register_buffer('self_energies', self_energies)
def sae_from_dataset(self, atomic_properties, properties):
    """Compute atomic self energies from dataset.

    Least-squares solution to a linear equation is calculated to output
    ``self_energies`` when ``self_energies = None`` is passed to
    :class:`torchani.EnergyShifter`

    Arguments:
        atomic_properties (dict): must contain ``'species'``.
            NOTE(review): presumably shaped (molecules, atoms), since the
            comparison result is summed over dimension 1 -- confirm.
        properties (dict): must contain ``'energies'``.

    Returns:
        numpy.ndarray: the fitted coefficients, one per entry returned by
        ``present_species``, followed by the intercept when
        ``self.fit_intercept`` is true.
    """
    species = atomic_properties['species']
    energies = properties['energies']
    present_species_ = present_species(species)
    # For every row, count how many entries equal each present species.
    X = (species.unsqueeze(-1) == present_species_).sum(dim=1).to(torch.double)
    # Concatenate a vector of ones to find fit intercept
    if self.fit_intercept:
        # Create the column on X's device; a default-device tensor would make
        # torch.cat fail when the dataset lives on another device.
        ones = torch.ones(X.shape[0], 1, dtype=torch.double, device=X.device)
        X = torch.cat((X, ones), dim=-1)
    y = energies.unsqueeze(dim=-1)
    # NumPy cannot read CUDA or grad-tracking tensors directly, so convert
    # explicitly before solving.
    coeff_, _, _, _ = np.linalg.lstsq(X.detach().cpu().numpy(),
                                      y.detach().cpu().numpy(),
                                      rcond=None)
    return coeff_.squeeze(-1)
def sae(self, species):
"""Compute self energies for molecules.
......@@ -171,19 +181,6 @@ class EnergyShifter(torch.nn.Module):
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double)
return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties):
    """Transformer that subtracts self energies from a dataset.

    If ``self.self_energies`` is ``None`` it is first fitted with
    ``self.sae_from_dataset`` and stored as a double tensor.  The value of
    ``self.sae(species)`` is then subtracted from ``properties['energies']``
    (converted to double), and the result is written back in place.

    Arguments:
        atomic_properties (dict): must contain ``'species'``.
        properties (dict): must contain ``'energies'``; modified in place.

    Returns:
        tuple: ``(atomic_properties, properties)``, the same objects that
        were passed in.
    """
    if self.self_energies is None:
        fitted = self.sae_from_dataset(atomic_properties, properties)
        self.self_energies = torch.tensor(fitted, dtype=torch.double)
    species = atomic_properties['species']
    raw_energies = properties['energies']
    # Move the shift to wherever the energies live before subtracting.
    shift = self.sae(species).to(raw_energies.device)
    properties['energies'] = raw_energies.to(torch.double) - shift
    return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
......@@ -414,6 +411,17 @@ def get_atomic_masses(species):
return masses
# Chemical element symbols from H to Og, listed in order of increasing atomic
# number and split into a flat list of strings, so that PERIODIC_TABLE[0] is
# 'H' and a symbol's list index is its atomic number minus one.
PERIODIC_TABLE = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
__all__ = ['pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts', 'get_atomic_masses']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment