Commit 2359a387 authored by Xiang Gao's avatar Xiang Gao Committed by Gao, Xiang
Browse files

torchani 0.1

parents
*.h5 filter=lfs diff=lfs merge=lfs -text
*.txt
*.prof
__pycache__
/data
*.cpp
a.out
/test.py
/.vscode
/build
/.eggs
/torchani.egg-info
/*.h5
/*.hdf5
.ipynb_checkpoints
benchmark_xyz
*.pyc
recursive-include torchani/resources *
\ No newline at end of file
1. Add GPU support for torch.unique
2. Add empty-tensor support to PyTorch
3. Add meshgrid support to PyTorch
4. Use @unittest.skipIf to skip tests when the features above are unavailable
from benchmark import Benchmark
import torchani
class ANIBenchmark(Benchmark):
    """Benchmark of the torchani implementation of ANI.

    Uses :class:`torchani.ModelOnAEV` with timing instrumentation enabled
    and reports the accumulated timers after each benchmark run.
    """

    def __init__(self, device):
        """Create a benchmark that runs on the given torch device."""
        super(ANIBenchmark, self).__init__(device)
        self.aev_computer = torchani.SortedAEV(device=device)
        # benchmark=True makes the model record timings; derivative=True
        # also computes forces so the 'derivative' timer is populated.
        self.model = torchani.ModelOnAEV(
            self.aev_computer, benchmark=True, derivative=True, from_nc=None)

    def _collect_timers(self):
        """Snapshot the model timers as a result dict, then reset them.

        Shared by `oneByOne` and `inBatch`, which previously duplicated
        this dictionary construction verbatim.
        """
        ret = {
            'aev': self.model.timers['aev'],
            'energy': self.model.timers['nn'],
            'force': self.model.timers['derivative'],
        }
        self.model.reset_timers()
        return ret

    def oneByOne(self, coordinates, species):
        """Run the model on each conformation separately; see `Benchmark`."""
        conformations = coordinates.shape[0]
        coordinates = coordinates.to(self.device)
        for i in range(conformations):
            c = coordinates[i:i+1, :, :]
            self.model(c, species)
        return self._collect_timers()

    def inBatch(self, coordinates, species):
        """Run the model on all conformations at once; see `Benchmark`."""
        coordinates = coordinates.to(self.device)
        self.model(coordinates, species)
        return self._collect_timers()
import numpy
class Benchmark:
    """Abstract base class for benchmarking ANI implementations."""

    def __init__(self, device):
        # Device on which the concrete implementation should compute.
        self.device = device

    def oneByOne(self, coordinates, species):
        """Benchmark energies and forces computed one conformation at a time.

        Parameters
        ----------
        coordinates : numpy.ndarray
            Array of shape (conformations, atoms, 3).
        species : list
            Species of each atom; its length equals the number of atoms
            in the molecule.

        Returns
        -------
        dict
            Timings in seconds, with keys:
            aev    : time to compute AEVs from coordinates with the given
                     neighbor list.
            energy : time to compute energies once AEVs are available.
            force  : time to compute forces once energies and AEVs are
                     available.
        """
        raise NotImplementedError('subclass must implement this method')

    def inBatch(self, coordinates, species):
        """Benchmark energies and forces computed in a single batch.

        Same signature and return value as `oneByOne`.
        """
        raise NotImplementedError('subclass must implement this method')
from ase import Atoms
from ase.calculators.tip3p import TIP3P, rOH, angleHOH
from ase.md import Langevin
import ase.units as units
from ase.io.trajectory import Trajectory
import numpy
import h5py
from rdkit import Chem
from rdkit.Chem import AllChem
# from asap3 import EMT
from ase.calculators.emt import EMT
from multiprocessing import Pool
from tqdm import tqdm, tqdm_notebook, trange
tqdm.monitor_interval = 0
from selected_system import mols, mol_file
import functools
# Number of MD snapshots to record per system.
conformations = 1024
# Thermostat temperature; multiplied by units.kB where used, so this is
# presumably in Kelvin — TODO confirm against ASE conventions.
T = 30
# Output HDF5 files: water boxes and the selected benchmark molecules.
fw = h5py.File("waters.hdf5", "w")
fm = h5py.File(mol_file, "w")
def save(h5file, name, species, coordinates):
    """Write `coordinates` to `h5file` under `name`, tagging the species.

    The species list is stored as a single space-joined string in the
    dataset's 'species' attribute.
    """
    h5file[name] = coordinates
    dataset = h5file[name]
    dataset.attrs['species'] = ' '.join(species)
def waterbox(x, y, z, tqdmpos):
    """Generate MD conformations of an x*y*z water box and save them to `fw`.

    Runs TIP3P Langevin dynamics and stores `conformations` position
    snapshots under the dataset name '<n>_waters' in the global HDF5 file
    `fw`.  `tqdmpos` is the row index of this job's tqdm progress bar
    (boxes are generated in parallel via multiprocessing).
    """
    name = '{}_waters'.format(x*y*z)
    # Set up water box at 20 deg C density
    # Half the H-O-H angle in radians, used to place the two hydrogens
    # symmetrically about the oxygen.
    a = angleHOH * numpy.pi / 180 / 2
    pos = [[0, 0, 0],
           [0, rOH * numpy.cos(a), rOH * numpy.sin(a)],
           [0, rOH * numpy.cos(a), -rOH * numpy.sin(a)]]
    atoms = Atoms('OH2', positions=pos)
    # Cube edge length such that one water molecule (18.01528 g/mol)
    # matches liquid water density at 20 deg C (0.9982 g/cm^3);
    # units appear to be Angstrom — TODO confirm.
    vol = ((18.01528 / 6.022140857e23) / (0.9982 / 1e24))**(1 / 3.)
    atoms.set_cell((vol, vol, vol))
    atoms.center()
    # Replicate the single molecule into an x*y*z grid of waters.
    atoms = atoms.repeat((x, y, z))
    atoms.set_pbc(False)
    species = atoms.get_chemical_symbols()
    atoms.calc = TIP3P()
    md = Langevin(atoms, 1 * units.fs, temperature=T *
                  units.kB, friction=0.01)
    def generator(n):
        # Yield atom positions after each 1 fs MD step, n times.
        for _ in trange(n, desc=name, position=tqdmpos):
            md.run(1)
            positions = atoms.get_positions()
            yield positions
    save(fw, name, species, numpy.stack(generator(conformations)))
def compute(smiles):
    """Embed a SMILES string and collect EMT Langevin MD snapshots.

    Returns a tuple ``(safe_name, species, coordinates)`` where
    ``safe_name`` is the SMILES with '/' replaced by '_' (valid HDF5 key),
    ``species`` is the per-atom symbol list, and ``coordinates`` has shape
    (conformations, atoms, 3).
    """
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)
    # Generate and relax an initial 3D geometry.
    AllChem.EmbedMolecule(mol, useRandomCoords=True)
    AllChem.UFFOptimizeMolecule(mol)
    pos = mol.GetConformer().GetPositions()
    natoms = mol.GetNumAtoms()
    species = [mol.GetAtomWithIdx(j).GetSymbol() for j in range(natoms)]
    atoms = Atoms(species, positions=pos)
    atoms.calc = EMT()
    md = Langevin(atoms, 1 * units.fs, temperature=T *
                  units.kB, friction=0.01)
    # One position snapshot after each 1 fs MD step.
    snapshots = []
    for _ in range(conformations):
        md.run(1)
        snapshots.append(atoms.get_positions())
    c = numpy.stack(snapshots)
    return smiles.replace('/', '_'), species, c
def molecules():
    """Run `compute` for every selected SMILES string, in parallel."""
    all_smiles = [s for group in mols.values() for s in group]
    with Pool() as pool:
        return pool.map(compute, all_smiles)
if __name__ == '__main__':
    # Generate and save the small-molecule trajectories first (in a
    # worker pool inside `molecules`).
    for i in molecules():
        save(fm, *i)
    print(list(fm.keys()))
    print('done with molecules')
    # Then generate the four water boxes in parallel; the last tuple
    # element is each job's tqdm progress-bar row.
    with Pool() as p:
        p.starmap(waterbox, [(10, 10, 10, 0), (20, 20, 10,
                                               1), (30, 30, 30, 2), (40, 40, 40, 3)])
    print(list(fw.keys()))
    print('done with water boxes')
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/gaoxiang/pytorch-dev/pytorchdev/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"WARNING:root:Unable to import NeuroChemAEV, please check your pyNeuroChem installation.\n"
]
}
],
"source": [
"import h5py\n",
"import torch\n",
"from selected_system import mols, mol_file\n",
"from ani_benchmark import ANIBenchmark\n",
"import pandas\n",
"import os\n",
"import tqdm\n",
"from IPython.display import display\n",
"import itertools\n",
"tqdm.monitor_interval = 0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"torch.set_num_threads(1)\n",
"fm = h5py.File(os.path.join('../',mol_file), \"r\")\n",
"\n",
"benchmarks = {\n",
" 'C': ANIBenchmark(device=torch.device(\"cpu\")),\n",
"}\n",
"\n",
"if torch.cuda.is_available():\n",
" benchmarks.update({\n",
" 'G': ANIBenchmark(device=torch.device(\"cuda\")),\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 20\n",
"Running benchmark on molecule COC(=O)c1ccc([N+](=O)[O-])cc1\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9ff5dae10b0e4a85b250095cc1c40f64",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>0.700781</td>\n",
" <td>0.537888</td>\n",
" <td>1.095591</td>\n",
" <td>0.022238</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>0.422604</td>\n",
" <td>0.183495</td>\n",
" <td>0.564654</td>\n",
" <td>0.002251</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>1.253797</td>\n",
" <td>0.561175</td>\n",
" <td>1.283622</td>\n",
" <td>0.010574</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>1.123385</td>\n",
" <td>0.721383</td>\n",
" <td>1.660245</td>\n",
" <td>0.024489</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>2.377182</td>\n",
" <td>1.282558</td>\n",
" <td>2.943866</td>\n",
" <td>0.035064</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 0.700781 0.537888 1.095591 0.022238\n",
"energy 0.422604 0.183495 0.564654 0.002251\n",
"force 1.253797 0.561175 1.283622 0.010574\n",
"forward 1.123385 0.721383 1.660245 0.024489\n",
"total 2.377182 1.282558 2.943866 0.035064"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 50\n",
"Running benchmark on molecule O=[N+]([O-])c1ccc(NN=Cc2ccc(C=NNc3ccc([N+](=O)[O-])cc3[N+](=O)[O-])cc2)c([N+](=O)[O-])c1\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dbb459e25032453e9d6d5a8dc84a08de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>1.680758</td>\n",
" <td>2.910254</td>\n",
" <td>1.120638</td>\n",
" <td>0.033106</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>0.546913</td>\n",
" <td>0.406957</td>\n",
" <td>0.474391</td>\n",
" <td>0.002537</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>3.637797</td>\n",
" <td>2.189924</td>\n",
" <td>1.495740</td>\n",
" <td>0.036373</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>2.227671</td>\n",
" <td>3.317210</td>\n",
" <td>1.595029</td>\n",
" <td>0.035643</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>5.865469</td>\n",
" <td>5.507135</td>\n",
" <td>3.090769</td>\n",
" <td>0.072016</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 1.680758 2.910254 1.120638 0.033106\n",
"energy 0.546913 0.406957 0.474391 0.002537\n",
"force 3.637797 2.189924 1.495740 0.036373\n",
"forward 2.227671 3.317210 1.595029 0.035643\n",
"total 5.865469 5.507135 3.090769 0.072016"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 10\n",
"Running benchmark on molecule N#CCC(=O)N\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4e04277c06f5497384bed8e910669212",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>0.420929</td>\n",
" <td>0.149769</td>\n",
" <td>1.081606</td>\n",
" <td>0.006245</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>0.288124</td>\n",
" <td>0.093577</td>\n",
" <td>0.446744</td>\n",
" <td>0.002521</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>0.755610</td>\n",
" <td>0.216397</td>\n",
" <td>1.218433</td>\n",
" <td>0.007300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>0.709054</td>\n",
" <td>0.243346</td>\n",
" <td>1.528349</td>\n",
" <td>0.008766</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>1.464663</td>\n",
" <td>0.459743</td>\n",
" <td>2.746782</td>\n",
" <td>0.016067</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 0.420929 0.149769 1.081606 0.006245\n",
"energy 0.288124 0.093577 0.446744 0.002521\n",
"force 0.755610 0.216397 1.218433 0.007300\n",
"forward 0.709054 0.243346 1.528349 0.008766\n",
"total 1.464663 0.459743 2.746782 0.016067"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 4,5,6\n",
"Running benchmark on molecule C\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "65906ab1bdf1470aaf7515189c163f09",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>0.207580</td>\n",
" <td>0.017271</td>\n",
" <td>0.775032</td>\n",
" <td>0.004428</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>0.149350</td>\n",
" <td>0.048953</td>\n",
" <td>0.219576</td>\n",
" <td>0.001350</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>0.390866</td>\n",
" <td>0.121347</td>\n",
" <td>0.768956</td>\n",
" <td>0.005004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>0.356931</td>\n",
" <td>0.066225</td>\n",
" <td>0.994608</td>\n",
" <td>0.005778</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>0.747796</td>\n",
" <td>0.187572</td>\n",
" <td>1.763565</td>\n",
" <td>0.010781</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 0.207580 0.017271 0.775032 0.004428\n",
"energy 0.149350 0.048953 0.219576 0.001350\n",
"force 0.390866 0.121347 0.768956 0.005004\n",
"forward 0.356931 0.066225 0.994608 0.005778\n",
"total 0.747796 0.187572 1.763565 0.010781"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 100\n",
"Running benchmark on molecule CC(C)C[C@@H](C(=O)O)NC(=O)C[C@@H]([C@H](CC1CCCCC1)NC(=O)CC[C@@H]([C@H](Cc2ccccc2)NC(=O)OC(C)(C)C)O)O\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb3a02673cc14b468839c8b3a0e84edc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>6.933071</td>\n",
" <td>11.386394</td>\n",
" <td>1.226802</td>\n",
" <td>0.118203</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>1.004400</td>\n",
" <td>0.848959</td>\n",
" <td>0.517654</td>\n",
" <td>0.001794</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>8.206886</td>\n",
" <td>6.606211</td>\n",
" <td>1.827498</td>\n",
" <td>0.108411</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>7.937472</td>\n",
" <td>12.235354</td>\n",
" <td>1.744456</td>\n",
" <td>0.119997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>16.144357</td>\n",
" <td>18.841565</td>\n",
" <td>3.571954</td>\n",
" <td>0.228409</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 6.933071 11.386394 1.226802 0.118203\n",
"energy 1.004400 0.848959 0.517654 0.001794\n",
"force 8.206886 6.606211 1.827498 0.108411\n",
"forward 7.937472 12.235354 1.744456 0.119997\n",
"total 16.144357 18.841565 3.571954 0.228409"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of atoms: 305\n",
"Running benchmark on molecule [H]/N=C(/N)\\NCCC[C@H](C(=O)N[C@H]([C@@H](C)O)C(=O)N[C@H](Cc1ccc(cc1)O)C(=O)NCCCC[C@@H](C(=O)NCCCC[C@@H](C(=O)NCC(=O)O)NC(=O)[C@H](CCCCNC(=O)[C@@H](Cc2ccc(cc2)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc3ccc(cc3)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc4ccc(cc4)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)N\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a01e82e74094812b0a041a95e42ae38",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda runtime error (2) : out of memory at /home/gaoxiang/pytorch/aten/src/THC/generic/THCStorage.cu:58\n",
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>C,1</th>\n",
" <th>C,B</th>\n",
" <th>G,1</th>\n",
" <th>G,B</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>aev</th>\n",
" <td>22.462832</td>\n",
" <td>38.629616</td>\n",
" <td>1.459434</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>energy</th>\n",
" <td>2.698251</td>\n",
" <td>2.419860</td>\n",
" <td>0.525845</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>force</th>\n",
" <td>21.112948</td>\n",
" <td>20.788418</td>\n",
" <td>3.065523</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>forward</th>\n",
" <td>25.161083</td>\n",
" <td>41.049476</td>\n",
" <td>1.985280</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td>46.274031</td>\n",
" <td>61.837895</td>\n",
" <td>5.050802</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" C,1 C,B G,1 G,B\n",
"aev 22.462832 38.629616 1.459434 NaN\n",
"energy 2.698251 2.419860 0.525845 NaN\n",
"force 21.112948 20.788418 3.065523 NaN\n",
"forward 25.161083 41.049476 1.985280 NaN\n",
"total 46.274031 61.837895 5.050802 NaN"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in mols:\n",
" print('number of atoms:', i)\n",
" smiles = mols[i]\n",
" for s in smiles:\n",
" print('Running benchmark on molecule', s)\n",
" key = s.replace('/', '_')\n",
" coordinates = torch.from_numpy(fm[key][()])\n",
" coordinates = coordinates[:200]\n",
" species = fm[key].attrs['species'].split()\n",
" results = {}\n",
" for b,m in tqdm.tqdm_notebook(list(itertools.product(benchmarks, ['1','B']))):\n",
" bench = benchmarks[b]\n",
" coordinates = coordinates.type(bench.aev_computer.dtype)\n",
" try:\n",
" if m == '1':\n",
" result = bench.oneByOne(coordinates, species)\n",
" elif m == 'B':\n",
" result = bench.inBatch(coordinates, species)\n",
" else:\n",
" raise ValueError('BUG here')\n",
" result['forward'] = result['aev'] + result['energy']\n",
" result['total'] = result['forward'] + result['force']\n",
" except RuntimeError as e:\n",
" print(e)\n",
" result = {'aev': None, 'energy': None, 'force': None, 'total': None }\n",
" results[b + ',' + m] = result\n",
" df = pandas.DataFrame(results)\n",
" display(df)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Benchmark systems grouped by atom count.  Keys are atom-count labels
# (as strings), values are lists of SMILES strings.
mols = {
    '20': [
        'COC(=O)c1ccc([N+](=O)[O-])cc1',
        'O=c1nnc2ccccc2n1CO',
        'CCc1ccc([N+](=O)[O-])cc1',
        'Nc1ccc(c2cnco2)cc1',
        'COc1ccc(N)c(N)c1',
        'O=C(O)CNc1ccccc1',
        'NC(=O)NNc1ccccc1',
        'Cn1c(=O)oc(=O)c2ccccc12',
        'CC(=O)Nc1ccc(O)cc1',
        'COc1ccc(CC#N)cc1'
    ],
    '50': [
        'O=[N+]([O-])c1ccc(NN=Cc2ccc(C=NNc3ccc([N+](=O)[O-])cc3[N+](=O)[O-])cc2)c([N+](=O)[O-])c1',
        'CCCCCc1nccnc1OCC(C)(C)CC(C)C',
        'CC(C)(C)c1ccc(N(C(=O)c2ccccc2)C(C)(C)C)cc1',
        'CCCCCCCCCCCOC(=O)Nc1ccccc1',
        'CC(=O)NCC(CN1CCCC1)(c1ccccc1)c1ccccc1',
        'CCCCCc1cnc(C)c(OCC(C)(C)CCC)n1',
        'CCCCCCCCCCCCN1CCOC(=O)C1',
        'CCCCOc1ccc(C=Nc2ccc(CCCC)cc2)cc1',
        'CC1CC(C)C(=NNC(=O)N)C(C(O)CC2CC(=O)NC(=O)C2)C1',
        'CCCCCOc1ccc(C=Nc2ccc(C(=O)OCC)cc2)cc1'
    ],
    '10': [
        'N#CCC(=O)N',
        'N#CCCO',
        'O=C1NC(=O)C(=O)N1',
        'COCC#N',
        'N#CCNC=O',
        'ON=CC=NO',
        'NCC(=O)O',
        'NC(=O)CO',
        'N#Cc1ccco1',
        'C=CC(=O)N'
    ],
    '4,5,6': [
        'C',
        'C#CC#N',
        'C=C',
        'CC#N',
        'C#CC#C',
        'O=CC#C',
        'C#C'
    ],
    '100': [
        'CC(C)C[C@@H](C(=O)O)NC(=O)C[C@@H]([C@H](CC1CCCCC1)NC(=O)CC[C@@H]([C@H](Cc2ccccc2)NC(=O)OC(C)(C)C)O)O',
        'CC(C)(C)OC(=O)N[C@@H](Cc1ccccc1)[C@@H](CN[C@@H](Cc2ccccc2)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](Cc3ccccc3)C(=O)N)O',
        'CC(C)(C)OC(=O)N[C@@H](Cc1ccccc1)[C@H](CN[C@@H](Cc2ccccc2)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](Cc3ccccc3)C(=O)N)O',
        'CC[C@H](c1ccc(cc1)O)[C@H](c2ccc(cc2)O)C(=O)OCCCCCCCCOC(=O)C(c3ccc(cc3)O)C(CC)c4ccc(cc4)O',
        'CC/C(=C\\CC[C@H](C)C[C@@H](C)CC[C@@H]([C@H](C)C(=O)C[C@H]([C@H](C)[C@@H](C)OC(=O)C[C@H](/C(=C(\\C)/C(=O)O)/C(=O)O)O)O)O)/C=C/C(=O)O',
        'CC[C@H](C)[C@H]1C(=O)NCCCOc2ccc(cc2)C[C@@H](C(=O)N1)NC(=O)[C@@H]3Cc4ccc(cc4)OCCCCC(=O)N[C@H](C(=O)N3)C(C)C',
        'CC(C)(C)CC(C)(C)c1ccc(cc1)OCCOCCOCCOCCOCCOCCOCCOCCOCCO',
        'CCOC(=O)CC[C@H](C[C@@H]1CCNC1=O)NC(=O)[C@H](Cc2ccccc2)NC(=O)[C@H](CCC(=O)OC(C)(C)C)NC(=O)OCc3ccccc3',
        'C[C@]12CC[C@@H]3c4ccc(cc4CC[C@H]3[C@@H]1C[C@@H]([C@@H]2O)CCCCCCCCC(=O)OC[C@@H]5[C@H]([C@H]([C@@H](O5)n6cnc7c6ncnc7N)O)O)O',
        'c1cc(ccc1CCc2c[nH]c3c2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)N[C@@H](CCC(=O)N[C@@H](CCC(=O)N[C@H](CCC(=O)O)C(=O)O)C(=O)O)C(=O)O)C(=O)O'
    ],
    '305': [
        '[H]/N=C(/N)\\NCCC[C@H](C(=O)N[C@H]([C@@H](C)O)C(=O)N[C@H](Cc1ccc(cc1)O)C(=O)NCCCC[C@@H](C(=O)NCCCC[C@@H](C(=O)NCC(=O)O)NC(=O)[C@H](CCCCNC(=O)[C@@H](Cc2ccc(cc2)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc3ccc(cc3)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)NC(=O)[C@@H](Cc4ccc(cc4)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\\[H])/N)N)N'
    ]
}
# HDF5 file holding the precomputed MD conformations of these molecules.
mol_file = "molecules.hdf5"
from selected_system import mols, mol_file
import h5py
import os
# Dump one molecule per size group to an .xyz file for external tools.
# Fixes: the outer loop variable `i` was being clobbered by the inner
# `for i in range(conformations)` loop (it only worked by accident thanks
# to the trailing `break`); also drops a pointless single-arg
# os.path.join call.
fm = h5py.File(mol_file, "r")
for natoms in mols:
    print('number of atoms:', natoms)
    smiles = mols[natoms]
    for s in smiles:
        # SMILES with '/' replaced by '_' is the HDF5 dataset key.
        key = s.replace('/', '_')
        filename = natoms
        with open('benchmark_xyz/' + filename + '.xyz', 'w') as fxyz:
            coordinates = fm[key][()]
            species = fm[key].attrs['species'].split()
            conformations = coordinates.shape[0]
            atoms = len(species)
            # Standard multi-frame xyz: atom count, comment line, then
            # one "<symbol> <x> <y> <z>" line per atom.
            for i in range(conformations):
                fxyz.write('{}\n{}\n'.format(
                    atoms, 'smiles:{}\tconformation:{}'.format(s, i)))
                for j in range(atoms):
                    ss = species[j]
                    xyz = coordinates[i, j, :]
                    x = xyz[0]
                    y = xyz[1]
                    z = xyz[2]
                    fxyz.write('{} {} {} {}\n'.format(ss, x, y, z))
        # Only the first molecule of each size group is exported.
        break
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = 'torchani'
copyright = '2018, Xiang Gao'
author = 'Xiang Gao'
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
# napoleon parses Google/NumPy-style docstrings for autodoc.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinx.ext.napoleon',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['sphinx_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
# NOTE(review): '.md' sources need a Markdown parser extension (e.g.
# recommonmark), which is not listed in `extensions` — confirm docs build.
source_suffix = '.md'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['sphinx_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'torchanidoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'torchani.tex', 'torchani Documentation',
     'Xiang Gao', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'torchani', 'torchani Documentation',
     [author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'torchani', 'torchani Documentation',
     author, 'torchani', 'One line description of project.',
     'Miscellaneous'),
]
# -- Extension configuration -------------------------------------------------
# -- Options for intersphinx extension ---------------------------------------
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None}
# -- Options for todo extension ----------------------------------------------
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# Include both the class docstring and __init__'s docstring in autodoc output.
autoclass_content = 'both'
import torch
import torchani
import torchani.data
import math
import timeit
import itertools
import os
import sys
import pickle
from tensorboardX import SummaryWriter
from tqdm import tqdm
from common import *
import sys
import json
# Each optimization batch consists of `batch_chunks` chunks of `chunk_size`
# conformations each, i.e. 1024 conformations per batch in total.
chunk_size = 256
batch_chunks = 1024 // chunk_size
# Pre-split dataset: (training, validation, testing) pickled together.
with open('data/dataset.dat', 'rb') as f:
    training, validation, testing = pickle.load(f)
training_sampler = torchani.data.BatchSampler(
    training, chunk_size, batch_chunks)
validation_sampler = torchani.data.BatchSampler(
    validation, chunk_size, batch_chunks)
testing_sampler = torchani.data.BatchSampler(
    testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader(
    training, batch_sampler=training_sampler, collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader(
    validation, batch_sampler=validation_sampler, collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader(
    testing, batch_sampler=testing_sampler, collate_fn=torchani.data.collate)
# sys.argv[1] is a JSON string of Adam keyword arguments (e.g. '{"lr": 1e-3}')
# and also tags the tensorboard run directory for this hyperparameter choice.
writer = SummaryWriter('runs/adam-{}'.format(sys.argv[1]))
checkpoint = 'checkpoint.pt'
# get_or_create_model, Averager, evaluate and aev_computer are presumably
# provided by the star import from `common` — TODO confirm.
model = get_or_create_model(checkpoint)
optimizer = torch.optim.Adam(model.parameters(), **json.loads(sys.argv[1]))
step = 0
epoch = 0
def subset_rmse(subset_dataloader):
    """Return the model's RMSE over an entire data subset.

    Accumulates per-molecule squared errors with an `Averager` and converts
    the mean via the 627.509 factor (presumably Hartree -> kcal/mol).
    """
    averager = Averager()
    for batch in subset_dataloader:
        for molecule_id in batch:
            _species = subset_dataloader.dataset.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(aev_computer.device)
            energies = energies.to(aev_computer.device)
            count, squared_error = evaluate(model, coordinates, energies, _species)
            averager.add(count, squared_error.item())
    return math.sqrt(averager.avg()) * 627.509
def optimize_step(a):
    """Take one optimizer step on the averaged squared error in `a`.

    Reads the module-level `writer`, `optimizer`, `step` and `epoch`.
    """
    mse = a.avg()
    # Log the training RMSE (627.509 converts to kcal/mol, presumably).
    writer.add_scalar('training_rmse_vs_step',
                      math.sqrt(mse.item()) * 627.509, step)
    # Plain MSE for the first 10 epochs, then an exponential loss that
    # penalizes large errors more strongly.
    loss = 0.5 * torch.exp(2 * mse) if epoch >= 10 else mse
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Train until the validation RMSE has not improved for 1000 epochs.
best_validation_rmse = math.inf
best_epoch = 0
start = timeit.default_timer()
while True:
    for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch), total=len(training_sampler)):
        a = Averager()
        for molecule_id in batch:
            _species = training.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(aev_computer.device)
            energies = energies.to(aev_computer.device)
            count, squared_error = evaluate(model, coordinates, energies, _species)
            # Normalize by molecule size so large molecules do not dominate.
            a.add(count, squared_error / len(_species))
        optimize_step(a)
        step += 1
    validation_rmse = subset_rmse(validation_dataloader)
    elapsed = round(timeit.default_timer() - start, 2)
    print('Epoch:', epoch, 'time:', elapsed,
          'validation rmse:', validation_rmse)
    writer.add_scalar('validation_rmse_vs_epoch', validation_rmse, epoch)
    writer.add_scalar('epoch_vs_step', epoch, step)
    writer.add_scalar('time_vs_epoch', elapsed, epoch)
    if validation_rmse < best_validation_rmse:
        best_validation_rmse = validation_rmse
        best_epoch = epoch
        writer.add_scalar('best_validation_rmse_vs_epoch',
                          best_validation_rmse, best_epoch)
    elif epoch - best_epoch > 1000:
        print('Stop at best validation rmse:', best_validation_rmse)
        break
    epoch += 1
testing_rmse = subset_rmse(testing_dataloader)
# BUG FIX: report the test-set RMSE; the original printed the stale
# validation RMSE even though testing_rmse had just been computed.
print('Test rmse:', testing_rmse)
import torch
import torchani
import torchani.data
import math
import timeit
import itertools
import os
import sys
import pickle
from tensorboardX import SummaryWriter
from tqdm import tqdm
from common import *
from copy import deepcopy
# Each minibatch is built from `batch_chunks` chunks of `chunk_size`
# conformations each, i.e. 1024 conformations per batch in total.
chunk_size = 256
batch_chunks = 1024 // chunk_size
# Pre-split dataset pickled by the data-preparation script.
with open('data/dataset.dat', 'rb') as f:
    training, validation, testing = pickle.load(f)
training_sampler = torchani.data.BatchSampler(
    training, chunk_size, batch_chunks)
validation_sampler = torchani.data.BatchSampler(
    validation, chunk_size, batch_chunks)
testing_sampler = torchani.data.BatchSampler(
    testing, chunk_size, batch_chunks)
training_dataloader = torch.utils.data.DataLoader(
    training, batch_sampler=training_sampler, collate_fn=torchani.data.collate)
validation_dataloader = torch.utils.data.DataLoader(
    validation, batch_sampler=validation_sampler, collate_fn=torchani.data.collate)
testing_dataloader = torch.utils.data.DataLoader(
    testing, batch_sampler=testing_sampler, collate_fn=torchani.data.collate)
# TensorBoard event writer (default run directory).
writer = SummaryWriter()
checkpoint = 'checkpoint.pt'
model = get_or_create_model(checkpoint)
# amsgrad variant of Adam; default learning rate.
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)
# Global counters used by the training loop and TensorBoard scalars below.
step = 0
epoch = 0
def subset_rmse(subset_dataloader):
    """Return the RMSE (kcal/mol) of `model` over all of `subset_dataloader`.

    Moves each chunk to the AEV computer's device, accumulates squared
    errors with an Averager, and converts the MSE from Hartree to kcal/mol.
    """
    a = Averager()
    for batch in subset_dataloader:
        for molecule_id in batch:
            _species = subset_dataloader.dataset.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(aev_computer.device)
            energies = energies.to(aev_computer.device)
            # Fixed: `evaluate` takes the model as its first argument (see
            # common.evaluate and the training loop below); it was omitted
            # here, which would raise a TypeError at call time.
            count, squared_error = evaluate(model, coordinates, energies, _species)
            squared_error = squared_error.item()
            a.add(count, squared_error)
    mse = a.avg()
    # 627.509 converts Hartree to kcal/mol.
    rmse = math.sqrt(mse) * 627.509
    return rmse
def optimize_step(a):
    """Perform one optimizer update from the statistics accumulated in `a`.

    Reads the module-level `writer`, `step`, `epoch` and `optimizer`.
    """
    mse = a.avg()
    # Convert Hartree to kcal/mol for the TensorBoard log.
    rmse = math.sqrt(mse.item()) * 627.509
    writer.add_scalar('training_rmse_vs_step', rmse, step)
    # Plain MSE for the first 10 epochs; afterwards an exponential loss
    # that penalizes large errors more strongly.
    loss = mse if epoch < 10 else 0.5 * torch.exp(2 * mse)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Main training loop with checkpointing and early stopping on validation RMSE.
best_validation_rmse = math.inf
best_epoch = 0
start = timeit.default_timer()
while True:
    # One epoch; each batch maps molecule id -> (coordinates, energies) chunk.
    for batch in tqdm(training_dataloader, desc='epoch {}'.format(epoch), total=len(training_sampler)):
        a = Averager()
        for molecule_id in batch:
            _species = training.species[molecule_id]
            coordinates, energies = batch[molecule_id]
            coordinates = coordinates.to(aev_computer.device)
            energies = energies.to(aev_computer.device)
            count, squared_error = evaluate(model, coordinates, energies, _species)
            # Normalize by molecule size so large molecules do not dominate.
            a.add(count, squared_error / len(_species))
        optimize_step(a)
        step += 1
    validation_rmse = subset_rmse(validation_dataloader)
    elapsed = round(timeit.default_timer() - start, 2)
    print('Epoch:', epoch, 'time:', elapsed,
          'validation rmse:', validation_rmse)
    writer.add_scalar('validation_rmse_vs_epoch', validation_rmse, epoch)
    writer.add_scalar('epoch_vs_step', epoch, step)
    writer.add_scalar('time_vs_epoch', elapsed, epoch)
    if validation_rmse < best_validation_rmse:
        best_validation_rmse = validation_rmse
        best_epoch = epoch
        writer.add_scalar('best_validation_rmse_vs_epoch',
                          best_validation_rmse, best_epoch)
        # Persist the best model seen so far.
        torch.save(model.state_dict(), checkpoint)
    elif epoch - best_epoch > 1000:
        # Early stopping: no improvement for 1000 consecutive epochs.
        print('Stop at best validation rmse:', best_validation_rmse)
        break
    epoch += 1
testing_rmse = subset_rmse(testing_dataloader)
# Fixed: previously printed validation_rmse; the freshly computed test RMSE
# was discarded.
print('Test rmse:', testing_rmse)
import pickle
import torch
# (chunk size, batch chunks) combinations whose gradient statistics were
# precomputed and pickled by the companion script.
hyperparams = [
    # (64, 4),
    (64, 8),
    (64, 16),
    (64, 32),
    (128, 2),
    (128, 4),
    (128, 8),
    (128, 16),
    (256, 1),
    (256, 2),
    (256, 4),
    (256, 8),
    (512, 1),
    (512, 2),
    (512, 4),
    (1024, 1),
    (1024, 2),
    (2048, 1),
]
for chunk_size, batch_chunks in hyperparams:
    filename = 'data/avg-{}-{}.dat'.format(chunk_size, batch_chunks)
    with open(filename, 'rb') as f:
        mean_grad, mean_grad_sqr = pickle.load(f)
    # Total gradient variance: sum over elements of E[g^2] - E[g]^2.
    total_variance = mean_grad_sqr.sum() - (mean_grad ** 2).sum()
    print(chunk_size, batch_chunks, torch.sqrt(total_variance).item())
import sys
import torch
import torchani
import configs
import torchani.data
import math
from tqdm import tqdm
import itertools
import os
import pickle
# Allow overriding the compute device from the command line (e.g. `cuda:0`).
# This must happen before importing `common`, which constructs the AEV
# computer from configs.device at import time.
if len(sys.argv) >= 2:
    configs.device = torch.device(sys.argv[1])
from common import *
ds = torchani.data.load_dataset(configs.data_path)
# just to conveniently zero grads
optimizer = torch.optim.Adam(model.parameters())
def grad_or_zero(parameter):
    """Return `parameter.grad` flattened to 1-D, or flat zeros if no grad."""
    grad = parameter.grad
    if grad is None:
        # Untouched parameters contribute a zero vector of matching length.
        return torch.zeros_like(parameter.reshape(-1))
    return grad.reshape(-1)
def batch_gradient(batch):
    """Return the flattened gradient of the batch MSE w.r.t. all parameters.

    Averages the squared error over the batch, backpropagates, and
    concatenates every parameter's gradient into a single 1-D tensor.
    """
    a = Averager()
    for molecule_id in batch:
        _species = ds.species[molecule_id]
        coordinates, energies = batch[molecule_id]
        coordinates = coordinates.to(aev_computer.device)
        energies = energies.to(aev_computer.device)
        # Fixed: common.evaluate takes the model as its first argument;
        # it was omitted here, which would raise a TypeError at call time.
        a.add(*evaluate(model, coordinates, energies, _species))
    mse = a.avg()
    optimizer.zero_grad()
    mse.backward()
    # Parameters untouched by this batch have grad None; substitute zeros so
    # the concatenated vector always has the same length.
    grads = [grad_or_zero(p) for p in model.parameters()]
    grads = torch.cat(grads)
    return grads
def compute(chunk_size, batch_chunks):
    """Accumulate E[grad] and E[grad^2] over the dataset for one batch shape.

    Pickles the (mean gradient, mean squared gradient) pair to
    data/avg-<chunk_size>-<batch_chunks>.dat.
    """
    sampler = torchani.data.BatchSampler(ds, chunk_size, batch_chunks)
    dataloader = torch.utils.data.DataLoader(
        ds, batch_sampler=sampler, collate_fn=torchani.data.collate)
    # Always restart from the same saved weights so different batch shapes
    # are compared at an identical point in parameter space; load to CPU
    # storages so a GPU-saved checkpoint also restores on CPU-only hosts.
    model.load_state_dict(torch.load(
        'data/model.pt', map_location=lambda storage, loc: storage))
    mean_grad = Averager()      # avg(grad)
    mean_grad_sqr = Averager()  # avg(grad^2)
    for batch in tqdm(dataloader, total=len(sampler)):
        g = batch_gradient(batch)
        mean_grad.add(1, g)
        mean_grad_sqr.add(1, g ** 2)
    with open('data/avg-{}-{}.dat'.format(chunk_size, batch_chunks), 'wb') as f:
        pickle.dump((mean_grad.avg(), mean_grad_sqr.avg()), f)
# Batch shape comes from the command line: <device> <chunk_size> <batch_chunks>.
chunk_size, batch_chunks = int(sys.argv[2]), int(sys.argv[3])
compute(chunk_size, batch_chunks)
# for chunk_size, batch_chunks in hyperparams:
#     compute(chunk_size, batch_chunks)
import torchani
import torch
import os
from configs import benchmark, device
class Averager:
    """Accumulate a running weighted total and report its mean."""

    def __init__(self):
        # count: number of items folded in; subtotal: their accumulated sum.
        self.count = 0
        self.subtotal = 0

    def add(self, count, subtotal):
        """Fold in `subtotal`, which covers `count` items."""
        self.count = self.count + count
        self.subtotal = self.subtotal + subtotal

    def avg(self):
        """Return the mean per item accumulated so far."""
        return self.subtotal / self.count
aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
def celu(x, alpha):
    """CELU activation: x where x > 0, else alpha * (exp(x / alpha) - 1)."""
    negative_branch = alpha * (torch.exp(x / alpha) - 1)
    return torch.where(x > 0, x, negative_branch)
class AtomicNetwork(torch.nn.Module):
    """Per-atom energy network mapping a 384-dim AEV to a scalar energy."""

    def __init__(self):
        super(AtomicNetwork, self).__init__()
        self.output_length = 1
        # Match the dtype/device of the shared AEV computer.
        dtype = aev_computer.dtype
        device = aev_computer.device
        self.layer1 = torch.nn.Linear(384, 128).type(dtype).to(device)
        self.layer2 = torch.nn.Linear(128, 128).type(dtype).to(device)
        self.layer3 = torch.nn.Linear(128, 64).type(dtype).to(device)
        self.layer4 = torch.nn.Linear(64, 1).type(dtype).to(device)

    def forward(self, aev):
        """Run the four linear layers with CELU(0.1) between hidden layers."""
        hidden = celu(self.layer1(aev), 0.1)
        hidden = celu(self.layer2(hidden), 0.1)
        hidden = celu(self.layer3(hidden), 0.1)
        return self.layer4(hidden)
def get_or_create_model(filename):
    """Build the per-species ANI model, loading or initializing its weights.

    If `filename` exists its state dict is loaded into the model; otherwise
    the freshly initialized weights are saved there so subsequent runs start
    from the same point.
    """
    model = torchani.ModelOnAEV(
        aev_computer,
        reducer=torch.sum,
        benchmark=benchmark,
        per_species={
            'C': AtomicNetwork(),
            'H': AtomicNetwork(),
            'N': AtomicNetwork(),
            'O': AtomicNetwork(),
        })
    if os.path.isfile(filename):
        # Load to CPU storages first so a checkpoint saved on GPU restores on
        # a CPU-only machine (consistent with the gradient-statistics script);
        # load_state_dict then copies values onto the live parameters.
        model.load_state_dict(torch.load(
            filename, map_location=lambda storage, loc: storage))
    else:
        torch.save(model.state_dict(), filename)
    return model
# Shifts predicted energies by per-species self atomic energies.
energy_shifter = torchani.EnergyShifter()
# Sum (not mean) of squared errors; averaging is done by Averager callers.
loss = torch.nn.MSELoss(size_average=False)
def evaluate(model, coordinates, energies, species):
    """Return (conformation count, summed squared energy error) for one chunk.

    The model's output gets the self atomic energies added back before being
    compared against the reference `energies`.
    """
    count = coordinates.shape[0]
    predicted = model(coordinates, species).squeeze()
    predicted = energy_shifter.add_sae(predicted, species)
    return count, loss(predicted, energies)
import torch
# Enable per-stage timing instrumentation when True.
benchmark = False
# HDF5 file containing the ANI-1x dataset.
data_path = 'data/ANI-1x_complete.h5'
# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torchani.data
import pickle
from configs import data_path
chunk_size = 64
dataset = torchani.data.load_dataset(data_path)
# Total number of chunks when the dataset is cut into 64-conformation chunks.
chunks = len(torchani.data.BatchSampler(dataset, chunk_size, 1))
print(chunks, 'chunks')
# 80/10/10 split by chunk count; the remainder goes to the test set so the
# three sizes always sum exactly to `chunks`.
training_size = int(chunks*0.8)
validation_size = int(chunks*0.1)
testing_size = chunks - training_size - validation_size
training, validation, testing = torchani.data.random_split(
    dataset, [training_size, validation_size, testing_size], chunk_size)
# Persist the split so training runs all see the same partition.
with open('data/dataset.dat', 'wb') as f:
    pickle.dump((training, validation, testing), f)
import torch
import torchani
device = torch.device('cpu')
# Constants, self-energy linear fit, and trained networks shipped with the
# ani-1x_dft_x8ens model resources.
const_file = '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params'
sae_file = '../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
network_dir = '../torchani/resources/ani-1x_dft_x8ens/train'
aev_computer = torchani.SortedAEV(const_file=const_file, device=device)
# Ensemble of 8 networks; derivative=True makes the model also return
# d(energy)/d(coordinates).
nn = torchani.ModelOnAEV(aev_computer, derivative=True,
                         from_nc=network_dir, ensemble=8)
shift_energy = torchani.EnergyShifter(sae_file)
# One methane conformation: (1 molecule, 5 atoms, 3 Cartesian coordinates).
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                             [-0.83140486, 0.39370209, -0.26395324],
                             [-0.66518241, -0.84461308, 0.20759389],
                             [0.45554739, 0.54289633, 0.81170881],
                             [0.66091919, -0.16799635, -0.91037834]]],
                           dtype=aev_computer.dtype, device=aev_computer.device)
species = ['C', 'H', 'H', 'H', 'H']
energy, derivative = nn(coordinates, species)
energy = shift_energy.add_sae(energy, species)
# Force is the negative gradient of the energy w.r.t. coordinates.
force = -derivative
print('Energy:', energy.item())
print('Force:', force.squeeze().numpy())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment