Unverified Commit 923c8af4 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Python2 Inference Support (#171)

parent 1b2faf43
queue:
name: Hosted Ubuntu 1604
timeoutInMinutes: 30
trigger:
batch: true
branches:
include:
- master
variables:
python.version: '2.7'
steps:
- task: UsePythonVersion@0
displayName: 'Use Python $(python.version)'
inputs:
versionSpec: '$(python.version)'
- script: 'azure/install_dependencies.sh && pip install .'
displayName: 'Install dependencies'
- script: 'python2 examples/energy_force.py'
displayName: Energy and Force Example
- script: 'python2 examples/ase_interface.py'
displayName: ASE Interface Example
......@@ -16,6 +16,7 @@ calculator.
###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase.optimize import BFGS
......
......@@ -9,6 +9,7 @@ TorchANI and can be used directly.
###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
import torch
import torchani
......
......@@ -27,12 +27,14 @@ at :attr:`torchani.ignite`, and more at :attr:`torchani.utils`.
from .utils import EnergyShifter
from .nn import ANIModel, Ensemble
from .aev import AEVComputer
from . import ignite
from . import utils
from . import neurochem
from . import data
from . import models
from pkg_resources import get_distribution, DistributionNotFound
import sys
if sys.version_info[0] > 2:
from . import ignite
from . import data
try:
__version__ = get_distribution(__name__).version
......
# Compatibility shim: ``math.inf`` was added in Python 3.5; backfill it
# for older interpreters (Python 2) so the rest of the package can use it.
import math

try:
    math.inf
except AttributeError:
    math.inf = float('inf')
import torch
import itertools
from . import _six # noqa:F401
import math
from . import utils
......
......@@ -5,6 +5,7 @@
https://wiki.fysik.dtu.dk/ase
"""
from __future__ import absolute_import
import math
import torch
import ase.neighborlist
......@@ -60,7 +61,7 @@ class NeighborList:
dtype=coordinates.dtype)
cell = torch.tensor(self.cell, device=coordinates.device,
dtype=coordinates.dtype)
D += shift @ cell
D += torch.mm(shift, cell)
d = D.norm(2, -1)
neighbor_species1 = []
neighbor_distances1 = []
......
# -*- coding: utf-8 -*-
"""Helpers for working with ignite."""
from __future__ import absolute_import
import torch
from . import utils
from torch.nn.modules.loss import _Loss
from ignite.metrics.metric import Metric
from ignite.metrics import RootMeanSquaredError
from ignite.metrics import Metric, RootMeanSquaredError
from ignite.contrib.metrics.regression import MaximumAbsoluteError
......
......@@ -11,14 +11,16 @@ import itertools
import ignite
import math
import timeit
from collections.abc import Mapping
from . import _six # noqa:F401
import collections
import sys
from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
class Constants(Mapping):
class Constants(collections.abc.Mapping):
"""NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
......@@ -259,7 +261,7 @@ def load_model_ensemble(species, prefix, count):
return Ensemble(models)
class BuiltinsAbstract:
class BuiltinsAbstract(object):
"""Base class for loading ANI neural network from configuration files.
Arguments:
......@@ -377,11 +379,11 @@ def hartree2kcal(x):
return 627.509 * x
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
if sys.version_info[0] > 2:
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
class Trainer:
class Trainer:
"""Train with NeuroChem training configurations.
Arguments:
......@@ -391,8 +393,8 @@ class Trainer:
tensorboard (str): Directory to store tensorboard log file, set to
``None`` to disable tensorboardX.
aev_caching (bool): Whether to use AEV caching.
checkpoint_name (str): Name of the checkpoint file, checkpoints will be
stored in the network directory with this file name.
checkpoint_name (str): Name of the checkpoint file, checkpoints
will be stored in the network directory with this file name.
"""
def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
......@@ -409,7 +411,8 @@ class Trainer:
self.tqdm = None
if tensorboard is not None:
import tensorboardX
self.tensorboard = tensorboardX.SummaryWriter(log_dir=tensorboard)
self.tensorboard = tensorboardX.SummaryWriter(
log_dir=tensorboard)
self.training_eval_every = 20
else:
self.tensorboard = None
......@@ -458,7 +461,7 @@ class Trainer:
%import common.WS
%ignore WS
%ignore /!.*/
''')
''') # noqa: E501
tree = parser.parse(txt)
class TreeExec(lark.Transformer):
......@@ -523,7 +526,8 @@ class Trainer:
params = yaml.safe_load(f)
network_setup = params['network_setup']
del params['network_setup']
network_setup = (network_setup['inputsize'], network_setup['atom_net'])
network_setup = (network_setup['inputsize'],
network_setup['atom_net'])
return network_setup, params
def _construct(self, network_setup, params):
......@@ -566,7 +570,8 @@ class Trainer:
network_dir = os.path.join(dir_, params['ntwkStoreDir'])
if not os.path.exists(network_dir):
os.makedirs(network_dir)
self.model_checkpoint = os.path.join(network_dir, self.checkpoint_name)
self.model_checkpoint = os.path.join(network_dir,
self.checkpoint_name)
del params['ntwkStoreDir']
self.max_nonimprove = params['tolr']
del params['tolr']
......@@ -605,8 +610,9 @@ class Trainer:
del layer['activation']
if 'l2norm' in layer:
if layer['l2norm'] == 1:
# NB: The "L2" implemented in NeuroChem is actually not
# L2 but weight decay. The difference of these two is:
# NB: The "L2" implemented in NeuroChem is actually
# not L2 but weight decay. The difference of these
# two is:
# https://arxiv.org/pdf/1711.05101.pdf
# There is a pull request on github/pytorch
# implementing AdamW, etc.:
......@@ -617,10 +623,12 @@ class Trainer:
del layer['l2norm']
del layer['l2valu']
if layer:
raise ValueError('unrecognized parameter in layer setup')
raise ValueError(
'unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.model = ANIModel([atomic_nets[s] for s in self.consts.species])
self.model = ANIModel([atomic_nets[s]
for s in self.consts.species])
if self.aev_caching:
self.nnp = self.model
else:
......@@ -713,7 +721,8 @@ class Trainer:
if trainer.state.rmse < self.best_validation_rmse:
trainer.state.no_improve_count = 0
self.best_validation_rmse = trainer.state.rmse
torch.save(self.model.state_dict(), self.model_checkpoint)
torch.save(self.model.state_dict(),
self.model_checkpoint)
else:
trainer.state.no_improve_count += 1
......@@ -727,8 +736,8 @@ class Trainer:
epoch = trainer.state.epoch
self.tensorboard.add_scalar('time_vs_epoch', elapsed,
epoch)
self.tensorboard.add_scalar('learning_rate_vs_epoch', lr,
epoch)
self.tensorboard.add_scalar('learning_rate_vs_epoch',
lr, epoch)
self.tensorboard.add_scalar('validation_rmse_vs_epoch',
trainer.state.rmse, epoch)
self.tensorboard.add_scalar('validation_mae_vs_epoch',
......@@ -736,18 +745,18 @@ class Trainer:
self.tensorboard.add_scalar(
'best_validation_rmse_vs_epoch',
self.best_validation_rmse, epoch)
self.tensorboard.add_scalar('no_improve_count_vs_epoch',
trainer.state.no_improve_count,
epoch)
self.tensorboard.add_scalar(
'no_improve_count_vs_epoch',
trainer.state.no_improve_count, epoch)
# compute training RMSE and MAE
if epoch % self.training_eval_every == 1:
training_rmse, training_mae = \
self.evaluate(self.training_set)
self.tensorboard.add_scalar('training_rmse_vs_epoch',
training_rmse, epoch)
self.tensorboard.add_scalar('training_mae_vs_epoch',
training_mae, epoch)
self.tensorboard.add_scalar(
'training_rmse_vs_epoch', training_rmse, epoch)
self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_mae, epoch)
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
......
# Compatibility shim: on Python 2 the abstract base classes live directly on
# ``collections`` and there is no ``collections.abc`` submodule. Alias the
# package itself so ``collections.abc.Mapping`` etc. resolve everywhere.
import collections

try:
    collections.abc
except AttributeError:
    collections.abc = collections
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment