Unverified commit 923c8af4, authored by Gao, Xiang and committed by GitHub

Python2 Inference Support (#171)

parent 1b2faf43
queue:
  name: Hosted Ubuntu 1604
  timeoutInMinutes: 30

trigger:
  batch: true
  branches:
    include:
    - master

variables:
  python.version: '2.7'

steps:
- task: UsePythonVersion@0
  displayName: 'Use Python $(python.version)'
  inputs:
    versionSpec: '$(python.version)'

- script: 'azure/install_dependencies.sh && pip install .'
  displayName: 'Install dependencies'

- script: 'python2 examples/energy_force.py'
  displayName: Energy and Force Example

- script: 'python2 examples/ase_interface.py'
  displayName: ASE Interface Example
@@ -16,6 +16,7 @@ calculator.
 ###############################################################################
 # To begin with, let's first import the modules we will use:
+from __future__ import print_function
 from ase.lattice.cubic import Diamond
 from ase.md.langevin import Langevin
 from ase.optimize import BFGS
...
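This `__future__` import is what lets the example scripts run unchanged under `python2` in the pipeline above. A minimal illustration (not part of the diff) of what it changes:

# Illustration only: why the examples add this import. Under Python 2,
# without it, the line below is parsed as the print *statement* applied to
# a tuple and shows "('energy', -40.46)"; with the import, print is an
# ordinary function on both Python 2 and 3 and shows "energy -40.46".
from __future__ import print_function

print('energy', -40.46)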
@@ -9,6 +9,7 @@ TorchANI and can be used directly.
 ###############################################################################
 # To begin with, let's first import the modules we will use:
+from __future__ import print_function
 import torch
 import torchani
...
@@ -27,12 +27,14 @@ at :attr:`torchani.ignite`, and more at :attr:`torchani.utils`.
 from .utils import EnergyShifter
 from .nn import ANIModel, Ensemble
 from .aev import AEVComputer
-from . import ignite
 from . import utils
 from . import neurochem
-from . import data
 from . import models
 from pkg_resources import get_distribution, DistributionNotFound
+import sys
+if sys.version_info[0] > 2:
+    from . import ignite
+    from . import data
 try:
     __version__ = get_distribution(__name__).version
...
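The guard keeps inference importable on Python 2 while the training-only submodules are pulled in only on Python 3. A small sketch of the resulting behaviour, assuming the hunk above is the package `__init__`:

# Sketch (not part of the diff): the core inference API imports everywhere,
# while the dataset loader and training helpers are Python 3 only.
import sys
import torchani

assert hasattr(torchani, 'AEVComputer')    # inference pieces: always present
assert hasattr(torchani, 'models')
if sys.version_info[0] > 2:
    from torchani import data              # dataset loading, Python 3 only
    from torchani import ignite            # training helpers, Python 3 only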
import math
if not hasattr(math, 'inf'):
    math.inf = float('inf')
...
 import torch
 import itertools
+from . import _six  # noqa:F401
 import math
 from . import utils
...
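The `_six` import above exists only for its side effect: it backfills names that Python 2 lacks so the rest of the code can keep Python 3 spellings. A minimal sketch of how the `math.inf` shim shown just above is consumed:

# Sketch: math.inf only exists natively on Python 3.5+, so the shim
# backfills it on Python 2 before any other torchani module needs it.
import math

if not hasattr(math, 'inf'):      # true only on Python 2
    math.inf = float('inf')       # same value as the Python 3 constant

best_validation_rmse = math.inf   # e.g. the initial "best RMSE" in neurochem
assert 1e300 < best_validation_rmse
print(best_validation_rmse)       # -> inf on both Python 2 and 3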
@@ -5,6 +5,7 @@
 https://wiki.fysik.dtu.dk/ase
 """
+from __future__ import absolute_import
 import math
 import torch
 import ase.neighborlist
@@ -60,7 +61,7 @@ class NeighborList:
                              dtype=coordinates.dtype)
         cell = torch.tensor(self.cell, device=coordinates.device,
                             dtype=coordinates.dtype)
-        D += shift @ cell
+        D += torch.mm(shift, cell)
         d = D.norm(2, -1)
         neighbor_species1 = []
         neighbor_distances1 = []
...
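The one functional change here swaps the `@` matrix-multiplication operator for `torch.mm`: `@` is Python 3.5+ syntax (PEP 465) and a `SyntaxError` under Python 2, while `torch.mm` computes the same 2-D product. A quick sketch of the equivalence (shapes are made up for illustration):

# Sketch: torch.mm(shift, cell) is the spelled-out form of "shift @ cell"
# for 2-D tensors; unlike "@", it also parses under Python 2.
import torch

shift = torch.randn(8, 3)   # illustrative shapes, not the real neighborlist data
cell = torch.randn(3, 3)

product = torch.mm(shift, cell)
assert product.shape == (8, 3)
assert torch.allclose(product, torch.matmul(shift, cell))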
 # -*- coding: utf-8 -*-
 """Helpers for working with ignite."""
+from __future__ import absolute_import
 import torch
 from . import utils
 from torch.nn.modules.loss import _Loss
-from ignite.metrics.metric import Metric
-from ignite.metrics import RootMeanSquaredError
+from ignite.metrics import Metric, RootMeanSquaredError
 from ignite.contrib.metrics.regression import MaximumAbsoluteError
...
@@ -11,14 +11,16 @@ import itertools
 import ignite
 import math
 import timeit
-from collections.abc import Mapping
+from . import _six  # noqa:F401
+import collections
+import sys
 from ..nn import ANIModel, Ensemble, Gaussian
 from ..utils import EnergyShifter, ChemicalSymbolsToInts
 from ..aev import AEVComputer
 from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric


-class Constants(Mapping):
+class Constants(collections.abc.Mapping):
     """NeuroChem constants. Objects of this class can be used as arguments
     to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
@@ -259,7 +261,7 @@ def load_model_ensemble(species, prefix, count):
     return Ensemble(models)


-class BuiltinsAbstract:
+class BuiltinsAbstract(object):
     """Base class for loading ANI neural network from configuration files.

     Arguments:
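The `Constants` docstring in the hunk above already spells out the intended use. A short sketch, assuming this module is importable as `torchani.neurochem`; the constants file name is only an illustrative placeholder:

# Sketch of the documented pattern: Constants unpacks into AEVComputer.
# 'rHCNO.params' is a placeholder path for a NeuroChem constants file.
from torchani.neurochem import Constants
from torchani.aev import AEVComputer

consts = Constants('rHCNO.params')
aev_computer = AEVComputer(**consts)
print(aev_computer.aev_length())   # AEV size implied by the constants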
@@ -377,408 +379,415 @@ def hartree2kcal(x):
    return 627.509 * x


if sys.version_info[0] > 2:
    from ..data import BatchedANIDataset  # noqa: E402
    from ..data import AEVCacheLoader  # noqa: E402

    class Trainer:
        """Train with NeuroChem training configurations.

        Arguments:
            filename (str): Input file name
            device (:class:`torch.device`): device to train the model
            tqdm (bool): whether to enable tqdm
            tensorboard (str): Directory to store tensorboard log file, set to
                ``None`` to disable tensorboardX.
            aev_caching (bool): Whether to use AEV caching.
            checkpoint_name (str): Name of the checkpoint file, checkpoints
                will be stored in the network directory with this file name.
        """

        def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
                     tensorboard=None, aev_caching=False,
                     checkpoint_name='model.pt'):
            self.filename = filename
            self.device = device
            self.aev_caching = aev_caching
            self.checkpoint_name = checkpoint_name

            if tqdm:
                import tqdm
                self.tqdm = tqdm.tqdm
            else:
                self.tqdm = None
            if tensorboard is not None:
                import tensorboardX
                self.tensorboard = tensorboardX.SummaryWriter(
                    log_dir=tensorboard)
                self.training_eval_every = 20
            else:
                self.tensorboard = None

            with open(filename, 'r') as f:
                if filename.endswith('.yaml') or filename.endswith('.yml'):
                    network_setup, params = self._parse_yaml(f)
                else:
                    network_setup, params = self._parse(f.read())
            self._construct(network_setup, params)

        def _parse(self, txt):
            parser = lark.Lark(r'''
            identifier : CNAME

            outer_assign : identifier "=" value
            params : outer_assign *

            inner_assign : identifier "=" value ";"
            input_size : "inputsize" "=" INT ";"

            layer : "layer" "[" inner_assign * "]"

            atom_type : WORD

            atom_net : "atom_net" atom_type "$" layer * "$"

            network_setup: "network_setup" "{" input_size atom_net * "}"

            start: params network_setup params

            value : SIGNED_INT
                  | SIGNED_FLOAT
                  | STRING_VALUE

            STRING_VALUE : ("_"|"-"|"."|"/"|LETTER)("_"|"-"|"."|"/"|LETTER|DIGIT)*

            %import common.SIGNED_NUMBER
            %import common.LETTER
            %import common.WORD
            %import common.DIGIT
            %import common.INT
            %import common.SIGNED_INT
            %import common.SIGNED_FLOAT
            %import common.CNAME
            %import common.WS
            %ignore WS
            %ignore /!.*/
            ''')  # noqa: E501
            tree = parser.parse(txt)

            class TreeExec(lark.Transformer):

                def identifier(self, v):
                    v = v[0].value
                    return v

                def value(self, v):
                    if len(v) == 1:
                        v = v[0]
                        if v.type == 'STRING_VALUE':
                            v = v.value
                        elif v.type == 'SIGNED_INT' or v.type == 'INT':
                            v = int(v.value)
                        elif v.type == 'SIGNED_FLOAT' or v.type == 'FLOAT':
                            v = float(v.value)
                        else:
                            raise ValueError('unexpected type')
                    else:
                        raise ValueError('length of value can only be 1 or 2')
                    return v

                def outer_assign(self, v):
                    name = v[0]
                    value = v[1]
                    return name, value

                inner_assign = outer_assign

                def params(self, v):
                    return v

                def network_setup(self, v):
                    intput_size = int(v[0])
                    atomic_nets = dict(v[1:])
                    return intput_size, atomic_nets

                def layer(self, v):
                    return dict(v)

                def atom_net(self, v):
                    atom_type = v[0]
                    layers = v[1:]
                    return atom_type, layers

                def atom_type(self, v):
                    return v[0].value

                def start(self, v):
                    network_setup = v[1]
                    del v[1]
                    return network_setup, dict(itertools.chain(*v))

                def input_size(self, v):
                    return v[0].value

            return TreeExec().transform(tree)

        def _parse_yaml(self, f):
            import yaml
            params = yaml.safe_load(f)
            network_setup = params['network_setup']
            del params['network_setup']
            network_setup = (network_setup['inputsize'],
                             network_setup['atom_net'])
            return network_setup, params

        def _construct(self, network_setup, params):
            dir_ = os.path.dirname(os.path.abspath(self.filename))

            # delete ignored params
            def del_if_exists(key):
                if key in params:
                    del params[key]

            def assert_param(key, value):
                if key in params and params[key] != value:
                    raise NotImplementedError(key + ' not supported yet')
                del params[key]

            del_if_exists('gpuid')
            del_if_exists('nkde')
            del_if_exists('fmult')
            del_if_exists('cmult')
            del_if_exists('decrate')
            del_if_exists('mu')
            assert_param('pbc', 0)
            assert_param('force', 0)
            assert_param('energy', 1)
            assert_param('moment', 'ADAM')
            assert_param('runtype', 'ANNP_CREATE_HDNN_AND_TRAIN')
            assert_param('adptlrn', 'OFF')
            assert_param('tmax', 0)
            assert_param('nmax', 0)
            assert_param('ntwshr', 0)

            # load parameters
            self.const_file = os.path.join(dir_, params['sflparamsfile'])
            self.consts = Constants(self.const_file)
            self.aev_computer = AEVComputer(**self.consts)
            del params['sflparamsfile']
            self.sae_file = os.path.join(dir_, params['atomEnergyFile'])
            self.shift_energy = load_sae(self.sae_file)
            del params['atomEnergyFile']
            network_dir = os.path.join(dir_, params['ntwkStoreDir'])
            if not os.path.exists(network_dir):
                os.makedirs(network_dir)
            self.model_checkpoint = os.path.join(network_dir,
                                                 self.checkpoint_name)
            del params['ntwkStoreDir']
            self.max_nonimprove = params['tolr']
            del params['tolr']
            self.init_lr = params['eta']
            del params['eta']
            self.lr_decay = params['emult']
            del params['emult']
            self.min_lr = params['tcrit']
            del params['tcrit']
            self.training_batch_size = params['tbtchsz']
            del params['tbtchsz']
            self.validation_batch_size = params['vbtchsz']
            del params['vbtchsz']

            # construct networks
            input_size, network_setup = network_setup
            if input_size != self.aev_computer.aev_length():
                raise ValueError('AEV size and input size does not match')
            l2reg = []
            atomic_nets = {}
            for atom_type in network_setup:
                layers = network_setup[atom_type]
                modules = []
                i = input_size
                for layer in layers:
                    o = layer['nodes']
                    del layer['nodes']
                    if layer['type'] != 0:
                        raise ValueError('Unsupported layer type')
                    del layer['type']
                    module = torch.nn.Linear(i, o)
                    modules.append(module)
                    activation = _get_activation(layer['activation'])
                    if activation is not None:
                        modules.append(activation)
                    del layer['activation']
                    if 'l2norm' in layer:
                        if layer['l2norm'] == 1:
                            # NB: The "L2" implemented in NeuroChem is actually
                            # not L2 but weight decay. The difference of these
                            # two is:
                            # https://arxiv.org/pdf/1711.05101.pdf
                            # There is a pull request on github/pytorch
                            # implementing AdamW, etc.:
                            # https://github.com/pytorch/pytorch/pull/4429
                            # There is no plan to support the "L2" settings in
                            # input file before AdamW get merged into pytorch.
                            raise NotImplementedError('L2 not supported yet')
                        del layer['l2norm']
                        del layer['l2valu']
                    if layer:
                        raise ValueError(
                            'unrecognized parameter in layer setup')
                    i = o
                atomic_nets[atom_type] = torch.nn.Sequential(*modules)
            self.model = ANIModel([atomic_nets[s]
                                   for s in self.consts.species])
            if self.aev_caching:
                self.nnp = self.model
            else:
                self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
            self.container = Container({'energies': self.nnp}).to(self.device)

            # losses
            def l2():
                return sum([c * (m.weight ** 2).sum() for c, m in l2reg])
            self.mse_loss = TransformedLoss(MSELoss('energies'),
                                            lambda x: x + l2())
            self.exp_loss = TransformedLoss(
                MSELoss('energies'),
                lambda x: 0.5 * (torch.exp(2 * x) - 1) + l2())

            if params:
                raise ValueError('unrecognized parameter')

            self.global_epoch = 0
            self.global_iteration = 0
            self.best_validation_rmse = math.inf

        def evaluate(self, dataset):
            """Evaluate on given dataset to compute RMSE and MAE."""
            evaluator = ignite.engine.create_supervised_evaluator(
                self.container,
                metrics={
                    'RMSE': RMSEMetric('energies'),
                    'MAE': MAEMetric('energies'),
                }
            )
            evaluator.run(dataset)
            metrics = evaluator.state.metrics
            return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])

        def load_data(self, training_path, validation_path):
            """Load training and validation dataset from file.

            If AEV caching is enabled, then the arguments are path to the
            cache directory, otherwise it should be path to the dataset.
            """
            if self.aev_caching:
                self.training_set = AEVCacheLoader(training_path)
                self.validation_set = AEVCacheLoader(validation_path)
            else:
                self.training_set = BatchedANIDataset(
                    training_path, self.consts.species_to_tensor,
                    self.training_batch_size, device=self.device,
                    transform=[self.shift_energy.subtract_from_dataset])
                self.validation_set = BatchedANIDataset(
                    validation_path, self.consts.species_to_tensor,
                    self.validation_batch_size, device=self.device,
                    transform=[self.shift_energy.subtract_from_dataset])

        def run(self):
            """Run the training"""
            start = timeit.default_timer()

            def decorate(trainer):

                @trainer.on(ignite.engine.Events.STARTED)
                def initialize(trainer):
                    trainer.state.no_improve_count = 0
                    trainer.state.epoch += self.global_epoch
                    trainer.state.iteration += self.global_iteration

                @trainer.on(ignite.engine.Events.COMPLETED)
                def finalize(trainer):
                    self.global_epoch = trainer.state.epoch
                    self.global_iteration = trainer.state.iteration

                if self.tqdm is not None:
                    @trainer.on(ignite.engine.Events.EPOCH_STARTED)
                    def init_tqdm(trainer):
                        trainer.state.tqdm = self.tqdm(
                            total=len(self.training_set), desc='epoch')

                    @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
                    def update_tqdm(trainer):
                        trainer.state.tqdm.update(1)

                    @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
                    def finalize_tqdm(trainer):
                        trainer.state.tqdm.close()

                @trainer.on(ignite.engine.Events.EPOCH_STARTED)
                def validation_and_checkpoint(trainer):
                    trainer.state.rmse, trainer.state.mae = \
                        self.evaluate(self.validation_set)
                    if trainer.state.rmse < self.best_validation_rmse:
                        trainer.state.no_improve_count = 0
                        self.best_validation_rmse = trainer.state.rmse
                        torch.save(self.model.state_dict(),
                                   self.model_checkpoint)
                    else:
                        trainer.state.no_improve_count += 1

                    if trainer.state.no_improve_count > self.max_nonimprove:
                        trainer.terminate()

                if self.tensorboard is not None:
                    @trainer.on(ignite.engine.Events.EPOCH_STARTED)
                    def log_per_epoch(trainer):
                        elapsed = round(timeit.default_timer() - start, 2)
                        epoch = trainer.state.epoch
                        self.tensorboard.add_scalar('time_vs_epoch', elapsed,
                                                    epoch)
                        self.tensorboard.add_scalar('learning_rate_vs_epoch',
                                                    lr, epoch)
                        self.tensorboard.add_scalar('validation_rmse_vs_epoch',
                                                    trainer.state.rmse, epoch)
                        self.tensorboard.add_scalar('validation_mae_vs_epoch',
                                                    trainer.state.mae, epoch)
                        self.tensorboard.add_scalar(
                            'best_validation_rmse_vs_epoch',
                            self.best_validation_rmse, epoch)
                        self.tensorboard.add_scalar(
                            'no_improve_count_vs_epoch',
                            trainer.state.no_improve_count, epoch)

                        # compute training RMSE and MAE
                        if epoch % self.training_eval_every == 1:
                            training_rmse, training_mae = \
                                self.evaluate(self.training_set)
                            self.tensorboard.add_scalar(
                                'training_rmse_vs_epoch', training_rmse, epoch)
                            self.tensorboard.add_scalar(
                                'training_mae_vs_epoch', training_mae, epoch)

                    @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
                    def log_loss(trainer):
                        iteration = trainer.state.iteration
                        loss = trainer.state.output
                        self.tensorboard.add_scalar('loss_vs_iteration',
                                                    loss, iteration)

            lr = self.init_lr

            # training using mse loss first until the validation MAE decrease
            # to < 1 Hartree
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
            trainer = ignite.engine.create_supervised_trainer(
                self.container, optimizer, self.mse_loss)
            decorate(trainer)

            @trainer.on(ignite.engine.Events.EPOCH_STARTED)
            def terminate_if_smaller_enough(trainer):
                if trainer.state.mae < 1.0:
                    trainer.terminate()

            trainer.run(self.training_set, max_epochs=math.inf)

            while lr > self.min_lr:
                optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
                trainer = ignite.engine.create_supervised_trainer(
                    self.container, optimizer, self.exp_loss)
                decorate(trainer)
                trainer.run(self.training_set, max_epochs=math.inf)
                lr *= self.lr_decay


__all__ = ['Constants', 'load_sae', 'load_model', 'load_model_ensemble',
...
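For context, a minimal usage sketch of the `Trainer` defined above, assuming the module is importable as `torchani.neurochem`; the input and dataset paths are placeholders, and this path requires Python 3 because it depends on the dataset loaders:

# Usage sketch only; 'train.ipt', 'training.h5' and 'validation.h5' are
# placeholder paths, not files shipped with the package.
import torch
from torchani.neurochem import Trainer

trainer = Trainer('train.ipt',                  # NeuroChem training input file
                  device=torch.device('cuda'),
                  tqdm=True,                    # per-epoch progress bars
                  tensorboard=None)             # set a directory to enable logging
trainer.load_data('training.h5', 'validation.h5')
trainer.run()                                   # trains until lr decays below tcrit
rmse, mae = trainer.evaluate(trainer.validation_set)
print(rmse, mae)                                # reported in kcal/mol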
import collections
if not hasattr(collections, 'abc'):
    collections.abc = collections
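The companion shim covers `collections.abc`, which is not a separate module on Python 2; aliasing it to `collections` (where Python 2 keeps the ABCs) lets Python 3 spellings such as `collections.abc.Mapping`, used by `Constants` above, work on both interpreters. A small self-contained sketch:

# Sketch: after the alias, code written against collections.abc also runs on
# Python 2. The Pairs class here is a made-up example, not torchani code.
import collections

if not hasattr(collections, 'abc'):       # true only on Python 2
    collections.abc = collections         # Python 2 keeps the ABCs here


class Pairs(collections.abc.Mapping):
    """A tiny read-only mapping, analogous in spirit to neurochem.Constants."""

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


print(dict(Pairs(Rcr=5.2, Rca=3.5)))      # behaves like a plain dict view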