Unverified Commit 2ad4126f authored by Gao, Xiang, committed by GitHub

Add AdamW optimizer, add NNP training docs that train a real ANI-1x model, remove h5py and ignite from dependencies, rename MAEMetric -> MaxAEMetric, use torch.utils.tensorboard to replace tensorboardX (#224)
parent 54ab56ee
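For downstream code, the TensorBoard switch and the metric rename are mechanical changes; a minimal migration sketch (hypothetical user code, not part of this commit):

# before this commit (hypothetical user code)
#   import tensorboardX
#   writer = tensorboardX.SummaryWriter(log_dir='runs')
#   max_ae = torchani.ignite.MAEMetric('energies')
# after this commit
import torch.utils.tensorboard
import torchani
writer = torch.utils.tensorboard.SummaryWriter(log_dir='runs')
max_ae = torchani.ignite.MaxAEMetric('energies')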
...@@ -23,7 +23,7 @@ steps: ...@@ -23,7 +23,7 @@ steps:
- script: 'git describe --exact-match --tags HEAD' - script: 'git describe --exact-match --tags HEAD'
displayName: 'Fail build on non-release commits' displayName: 'Fail build on non-release commits'
- script: 'azure/install_dependencies.sh && pip install sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery && pip install .' - script: 'azure/install_dependencies.sh && pip install h5py pytorch-ignite-nightly tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery && pip install .'
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: 'sphinx-build docs build' - script: 'sphinx-build docs build'
......
...@@ -18,7 +18,7 @@ steps: ...@@ -18,7 +18,7 @@ steps:
inputs: inputs:
versionSpec: '$(python.version)' versionSpec: '$(python.version)'
- script: 'azure/install_dependencies.sh && pip install sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery && pip install .' - script: 'azure/install_dependencies.sh && pip install h5py pytorch-ignite-nightly tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery && pip install .'
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: 'sphinx-build docs build' - script: 'sphinx-build docs build'
......
...@@ -2,4 +2,4 @@ ...@@ -2,4 +2,4 @@
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install tqdm ase tensorboardX pyyaml pip install tqdm ase pyyaml future
...@@ -19,7 +19,7 @@ steps: ...@@ -19,7 +19,7 @@ steps:
inputs: inputs:
versionSpec: '$(python.version)' versionSpec: '$(python.version)'
- script: 'azure/install_dependencies.sh && pip install future && pip install .' - script: 'azure/install_dependencies.sh && pip install .'
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: 'python2 examples/energy_force.py' - script: 'python2 examples/energy_force.py'
......
...@@ -18,14 +18,17 @@ steps: ...@@ -18,14 +18,17 @@ steps:
inputs: inputs:
versionSpec: '$(python.version)' versionSpec: '$(python.version)'
- script: 'azure/install_dependencies.sh && pip install .' - script: 'azure/install_dependencies.sh && pip install h5py .'
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: 'python -m torchani.data.cache_aev tmp dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 256'
displayName: Cache AEV
- script: 'pip install pytorch-ignite-nightly'
displayName: 'Install more dependencies'
- script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5' - script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: NeuroChem Trainer displayName: NeuroChem Trainer
- script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.yaml dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5' - script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.yaml dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: NeuroChem Trainer YAML config displayName: NeuroChem Trainer YAML config
- script: 'python -m torchani.data.cache_aev tmp dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 256'
displayName: Cache AEV
...@@ -25,14 +25,20 @@ steps: ...@@ -25,14 +25,20 @@ steps:
- script: 'azure/install_dependencies.sh && pip install .' - script: 'azure/install_dependencies.sh && pip install .'
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: 'python tools/training-benchmark.py ./dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: Training Benchmark
- script: 'python tools/neurochem-test.py ./dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: NeuroChem Test
- script: 'python tools/inference-benchmark.py --tqdm ./dataset/xyz_files/CH4-5.xyz' - script: 'python tools/inference-benchmark.py --tqdm ./dataset/xyz_files/CH4-5.xyz'
displayName: Inference Benchmark displayName: Inference Benchmark
- script: 'pip install h5py'
displayName: 'Install more dependencies'
- script: 'python tools/comp6.py ./dataset/COMP6/COMP6v1/s66x8' - script: 'python tools/comp6.py ./dataset/COMP6/COMP6v1/s66x8'
displayName: COMP6 Benchmark displayName: COMP6 Benchmark
- script: 'pip install pytorch-ignite-nightly'
displayName: 'Install more dependencies'
- script: 'python tools/training-benchmark.py ./dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: Training Benchmark
- script: 'python tools/neurochem-test.py ./dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
displayName: NeuroChem Test
...@@ -66,6 +66,14 @@ ASE Interface ...@@ -66,6 +66,14 @@ ASE Interface
.. automodule:: torchani.ase .. automodule:: torchani.ase
.. autoclass:: torchani.ase.Calculator .. autoclass:: torchani.ase.Calculator
TorchANI Optimizer
==================
.. automodule:: torchani.optim
.. autoclass:: torchani.optim.AdamW
Ignite Helpers Ignite Helpers
============== ==============
...@@ -78,4 +86,4 @@ Ignite Helpers ...@@ -78,4 +86,4 @@ Ignite Helpers
.. autofunction:: torchani.ignite.MSELoss .. autofunction:: torchani.ignite.MSELoss
.. autoclass:: torchani.ignite.DictMetric .. autoclass:: torchani.ignite.DictMetric
.. autofunction:: torchani.ignite.RMSEMetric .. autofunction:: torchani.ignite.RMSEMetric
.. autofunction:: torchani.ignite.MAEMetric .. autofunction:: torchani.ignite.MaxAEMetric
...@@ -19,6 +19,7 @@ Welcome to TorchANI's documentation! ...@@ -19,6 +19,7 @@ Welcome to TorchANI's documentation!
examples/vibration_analysis examples/vibration_analysis
examples/load_from_neurochem examples/load_from_neurochem
examples/nnp_training examples/nnp_training
examples/nnp_training_ignite
examples/cache_aev examples/cache_aev
examples/neurochem_trainer examples/neurochem_trainer
......
...@@ -16,9 +16,9 @@ import torch ...@@ -16,9 +16,9 @@ import torch
import ignite import ignite
import torchani import torchani
import timeit import timeit
import tensorboardX
import os import os
import ignite.contrib.handlers import ignite.contrib.handlers
import torch.utils.tensorboard
# training and validation set # training and validation set
...@@ -98,7 +98,7 @@ model = nn.to(device) ...@@ -98,7 +98,7 @@ model = nn.to(device)
############################################################################### ###############################################################################
# This part is also a line by line copy # This part is also a line by line copy
writer = tensorboardX.SummaryWriter(log_dir=log) writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
############################################################################### ###############################################################################
# Here we don't need to construct :class:`torchani.data.BatchedANIDataset` # Here we don't need to construct :class:`torchani.data.BatchedANIDataset`
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
.. _neurochem-training:
Train Neural Network Potential From NeuroChem Input File Train Neural Network Potential From NeuroChem Input File
======================================================== ========================================================
......
...@@ -5,116 +5,93 @@ ...@@ -5,116 +5,93 @@
Train Your Own Neural Network Potential Train Your Own Neural Network Potential
======================================= =======================================
This example shows how to use TorchANI train your own neural network potential. This example shows how to use TorchANI to train a neural network potential. We
will use the same configuration as specified in `inputtrain.ipt`_
.. _`inputtrain.ipt`:
https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/inputtrain.ipt
.. note::
TorchANI provides tools to run the NeuroChem training config file `inputtrain.ipt`.
See: :ref:`neurochem-training`.
""" """
############################################################################### ###############################################################################
# To begin with, let's first import the modules we will use: # To begin with, let's first import the modules and set up the device we will use:
import torch import torch
import ignite
import torchani import torchani
import timeit
import tensorboardX
import os import os
import ignite.contrib.handlers import math
import torch.utils.tensorboard
import tqdm
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###############################################################################
# Now let's set up the constants and construct an AEV computer. These numbers can
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_ and `sae_linfit.dat`_
#
# .. note::
#
# Besides defining these hyperparameters programmatically,
# :mod:`torchani.neurochem` provides tools to read them from files. See also
# :ref:`training-example-ignite` for an example of usage.
#
# .. _rHCNO-5.2R_16-3.5A_a4-8.params:
# https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
# .. _sae_linfit.dat:
# https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/sae_linfit.dat
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter([
-0.600952980000, # H
-38.08316124000, # C
-54.70775770000, # N
-75.19446356000, # O
])
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
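As a quick sanity check on these hyperparameters, the length of the AEV they produce can be worked out by hand (assuming the usual ANI layout: one radial block per species and one angular block per unordered species pair); it is this number that fixes the input size of the atomic networks defined later:

radial_length = num_species * len(EtaR) * len(ShfR)        # 4 * 1 * 16 = 64
angular_length = (num_species * (num_species + 1) // 2) \
    * len(EtaA) * len(Zeta) * len(ShfA) * len(ShfZ)        # 10 * 1 * 1 * 4 * 8 = 320
print(radial_length + angular_length)                      # 384, matching Linear(384, ...) below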
############################################################################### ###############################################################################
# Now let's setup training hyperparameters. Note that here for our demo purpose # Now let's set up the datasets. Note that for demo purposes, we set both
# , we set both training set and validation set the ``ani_gdb_s01.h5`` in # the training and validation sets to the ``ani_gdb_s01.h5`` file in TorchANI's
# TorchANI's repository. This allows this program to finish very quick, because # repository. This allows this program to finish very quickly, because that
# that dataset is very small. But this is wrong and should be avoided for any # dataset is very small. But this is wrong and should be avoided for any
# serious training. These paths assumes the user run this script under the # serious training. These paths assume the user runs this script under the
# ``examples`` directory of TorchANI's repository. If you download this script, # ``examples`` directory of TorchANI's repository. If you download this script,
# you should manually set the path of these files in your system before this # you should manually set the path of these files in your system before this
# script can run successfully. # script can run successfully.
#
# Also note that we subtract the self energies of all atoms from the energy of
# each molecule. This brings the energies into a reasonable range. The second
# argument defines how to convert species, given as a list of strings, into a
# tensor, that is, which chemical symbol corresponds to ``0``, which corresponds
# to ``1``, etc.
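For instance, with the ``'HCNO'`` ordering chosen above, H maps to ``0``, C to ``1``, N to ``2`` and O to ``3``; a minimal sketch (assuming the converter accepts a plain list of symbols):

print(species_to_tensor(['C', 'H', 'H', 'H', 'H']))  # expected: tensor([1, 0, 0, 0, 0]) for methane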
# training and validation set
try: try:
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
except NameError: except NameError:
path = os.getcwd() path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') # noqa: E501 validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
# checkpoint file to save model when validation RMSE improves
model_checkpoint = 'model.pt'
# max epochs to run the training
max_epochs = 20
# Compute training RMSE every this steps. Since the training set is usually
# huge and the loss funcition does not directly gives us RMSE, we need to
# check the training RMSE to see overfitting.
training_rmse_every = 5
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# batch size
batch_size = 1024
# log directory for tensorboardX
log = 'runs'
############################################################################### batch_size = 2560
# Now let's read our constants and self energies from constant files and
# construct AEV computer.
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
sae_file = os.path.join(path, '../torchani/resources/ani-1x_8x/sae_linfit.dat') # noqa: E501
consts = torchani.neurochem.Constants(const_file)
aev_computer = torchani.AEVComputer(**consts)
energy_shifter = torchani.neurochem.load_sae(sae_file)
###############################################################################
# Now let's define atomic neural networks. Here in this demo, we use the same
# size of neural network for all atom types, but this is not necessary.
def atomic():
model = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 64),
torch.nn.CELU(0.1),
torch.nn.Linear(64, 1)
)
return model
nn = torchani.ANIModel([atomic() for _ in range(4)])
print(nn)
###############################################################################
# If checkpoint from previous training exists, then load it.
if os.path.isfile(model_checkpoint):
nn.load_state_dict(torch.load(model_checkpoint))
else:
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Now setup tensorboardX.
writer = tensorboardX.SummaryWriter(log_dir=log)
###############################################################################
# Now load training and validation datasets into memory. Note that we need to
# subtracting energies by the self energies of all atoms for each molecule.
# This makes the range of energies in a reasonable range. The second argument
# defines how to convert species as a list of string to tensor, that is, for
# all supported chemical symbols, which is correspond to ``0``, which
# correspond to ``1``, etc.
training = torchani.data.BatchedANIDataset( training = torchani.data.BatchedANIDataset(
training_path, consts.species_to_tensor, batch_size, device=device, training_path, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset]) transform=[energy_shifter.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset( validation = torchani.data.BatchedANIDataset(
validation_path, consts.species_to_tensor, batch_size, device=device, validation_path, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset]) transform=[energy_shifter.subtract_from_dataset])
############################################################################### ###############################################################################
...@@ -140,77 +117,243 @@ validation = torchani.data.BatchedANIDataset( ...@@ -140,77 +117,243 @@ validation = torchani.data.BatchedANIDataset(
# #
# The output, i.e. ``properties`` is a dictionary holding each property. This # The output, i.e. ``properties`` is a dictionary holding each property. This
# allows us to extend TorchANI in the future to training forces and properties. # allows us to extend TorchANI in the future to training forces and properties.
#
# We have tools to deal with these data types at :attr:`torchani.ignite` that
# allow us to easily combine the dataset with pytorch ignite. These tools can
# be used as follows:
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
############################################################################### ###############################################################################
# Let's add a progress bar for the trainer # Now let's define atomic neural networks.
pbar = ignite.contrib.handlers.ProgressBar()
pbar.attach(trainer)
H_network = torch.nn.Sequential(
torch.nn.Linear(384, 160),
torch.nn.CELU(0.1),
torch.nn.Linear(160, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
############################################################################### C_network = torch.nn.Sequential(
# And some event handlers to compute validation and training metrics: torch.nn.Linear(384, 144),
def hartree2kcal(x): torch.nn.CELU(0.1),
return 627.509 * x torch.nn.Linear(144, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
N_network = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 112),
torch.nn.CELU(0.1),
torch.nn.Linear(112, 96),
torch.nn.CELU(0.1),
torch.nn.Linear(96, 1)
)
@trainer.on(ignite.engine.Events.EPOCH_STARTED) O_network = torch.nn.Sequential(
def validation_and_checkpoint(trainer): torch.nn.Linear(384, 128),
def evaluate(dataset, name): torch.nn.CELU(0.1),
evaluator = ignite.engine.create_supervised_evaluator( torch.nn.Linear(128, 112),
container, torch.nn.CELU(0.1),
metrics={ torch.nn.Linear(112, 96),
'RMSE': torchani.ignite.RMSEMetric('energies') torch.nn.CELU(0.1),
} torch.nn.Linear(96, 1)
) )
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar(name, rmse, trainer.state.epoch)
# compute validation RMSE nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
evaluate(validation, 'validation_rmse_vs_epoch') print(nn)
# compute training RMSE ###############################################################################
if trainer.state.epoch % training_rmse_every == 1: # Let's now create a pipeline of AEV Computer --> Neural Networks.
evaluate(training, 'training_rmse_vs_epoch') model = torch.nn.Sequential(aev_computer, nn).to(device)
# checkpoint model ###############################################################################
torch.save(nn.state_dict(), model_checkpoint) # Now let's set up the optimizer. We need to specify different weight decay rates
# for different parameters. Since PyTorch does not currently have a correct
# implementation of weight decay, we provide a correct implementation in TorchANI.
#
# .. note::
#
# The weight decay in `inputtrain.ipt`_ is named "l2", but it is not actually
# L2 regularization. Confusing L2 regularization with weight decay is a common
# mistake in deep learning; see `Decoupled Weight Decay Regularization`_.
# Also note that, when training ANI models, weight decay is applied only to
# the weights, not to the biases.
#
# .. _Decoupled Weight Decay Regularization:
# https://arxiv.org/abs/1711.05101
optimizer = torchani.optim.AdamW([
# H networks
{'params': [H_network[0].weight], 'weight_decay': 0.0001},
{'params': [H_network[0].bias]},
{'params': [H_network[2].weight], 'weight_decay': 0.00001},
{'params': [H_network[2].bias]},
{'params': [H_network[4].weight], 'weight_decay': 0.000001},
{'params': [H_network[4].bias]},
{'params': H_network[6].parameters()},
# C networks
{'params': [C_network[0].weight], 'weight_decay': 0.0001},
{'params': [C_network[0].bias]},
{'params': [C_network[2].weight], 'weight_decay': 0.00001},
{'params': [C_network[2].bias]},
{'params': [C_network[4].weight], 'weight_decay': 0.000001},
{'params': [C_network[4].bias]},
{'params': C_network[6].parameters()},
# N networks
{'params': [N_network[0].weight], 'weight_decay': 0.0001},
{'params': [N_network[0].bias]},
{'params': [N_network[2].weight], 'weight_decay': 0.00001},
{'params': [N_network[2].bias]},
{'params': [N_network[4].weight], 'weight_decay': 0.000001},
{'params': [N_network[4].bias]},
{'params': N_network[6].parameters()},
# O networks
{'params': [O_network[0].weight], 'weight_decay': 0.0001},
{'params': [O_network[0].bias]},
{'params': [O_network[2].weight], 'weight_decay': 0.00001},
{'params': [O_network[2].bias]},
{'params': [O_network[4].weight], 'weight_decay': 0.000001},
{'params': [O_network[4].bias]},
{'params': O_network[6].parameters()},
])
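To make the difference concrete, here is a toy sketch of the two update rules on a single scalar weight (Adam's preconditioning is replaced by an identity step purely for illustration):

w, grad, lr, wd = 1.0, 0.2, 1e-3, 1e-4      # toy numbers, purely illustrative
adam_step = grad                             # stand-in for Adam's preconditioned step
w_l2 = w - lr * (adam_step + wd * w)         # "l2"-style: penalty folded into the gradient
w_decoupled = w - lr * adam_step - wd * w    # decoupled decay, as in the AdamW added by this commit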
###############################################################################
# The way ANI trains a neural network potential looks like this:
#
# Phase 1: Pretrain the model by minimizing MSE loss
#
# Phase 2: Train the model by minimizing the exponential loss; whenever the
# validation RMSE has not improved for a certain number of steps, decay the
# learning rate and repeat, stopping once the learning rate falls below a threshold.
#
# We first read the checkpoint files to find where we are. We use `latest.pt`
# to store the current training state. If `latest.pt` does not exist, this
# means the pretraining has not finished yet.
latest_checkpoint = 'latest.pt'
pretrained = os.path.isfile(latest_checkpoint)
############################################################################### ###############################################################################
# Also some to log elapsed time: # If the model is not pretrained yet, we need to run the pretrain.
start = timeit.default_timer() pretrain_epoches = 10
mse = torch.nn.MSELoss(reduction='none')
if not pretrained:
print("pre-training...")
epoch = 0
for _ in range(pretrain_epoches):
for batch_x, batch_y in tqdm.tqdm(training):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save({
'nn': nn.state_dict(),
'optimizer': optimizer.state_dict(),
}, latest_checkpoint)
###############################################################################
# For phase 2, we need a learning rate scheduler to do learning rate decay
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100)
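# With factor=0.5 and patience=100, ReduceLROnPlateau halves the learning rate
# whenever the validation RMSE passed to scheduler.step() below has not improved
# for 100 consecutive epochs.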
@trainer.on(ignite.engine.Events.EPOCH_STARTED) ###############################################################################
def log_time(trainer): # We will also use TensorBoard to visualize our training process
elapsed = round(timeit.default_timer() - start, 2) tensorboard = torch.utils.tensorboard.SummaryWriter()
writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)
###############################################################################
# Resume training from previously saved checkpoints:
checkpoint = torch.load(latest_checkpoint)
nn.load_state_dict(checkpoint['nn'])
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
scheduler.load_state_dict(checkpoint['scheduler'])
############################################################################### ###############################################################################
# Also log the loss per iteration: # During training, we need to validate on the validation set, and if the validation
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED) # error improves on the best so far, save the new best model to a checkpoint
def log_loss(trainer):
iteration = trainer.state.iteration
writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration) # helper function to convert energy unit from Hartree to kcal/mol
def hartree2kcal(x):
return 627.509 * x
def validate():
# run validation
mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse = 0.0
count = 0
for batch_x, batch_y in validation:
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0]
return hartree2kcal(math.sqrt(total_mse / count))
############################################################################### ###############################################################################
# And finally, we are ready to run: # Finally, we come to the training loop.
trainer.run(training, max_epochs) #
# In this tutorial, we set the maximum number of epochs to a very small value,
# only to make this demo terminate quickly. For serious training, it should be
# set to a much larger value.
print("training starting from epoch", scheduler.last_epoch + 1)
max_epochs = 200
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'
for _ in range(scheduler.last_epoch + 1, max_epochs):
rmse = validate()
learning_rate = optimizer.param_groups[0]['lr']
if learning_rate < early_stopping_learning_rate:
break
tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
# checkpoint
if scheduler.is_better(rmse, scheduler.best):
torch.save(nn.state_dict(), best_model_checkpoint)
scheduler.step(rmse)
for i, (batch_x, batch_y) in tqdm.tqdm(enumerate(training), total=len(training)):
true_energies = batch_y['energies']
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
loss = 0.5 * (torch.exp(2 * loss) - 1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# write current batch loss to TensorBoard
tensorboard.add_scalar('batch_loss', loss, scheduler.last_epoch * len(training) + i)
torch.save({
'nn': nn.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
}, latest_checkpoint)
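The ``0.5 * (torch.exp(2 * loss) - 1)`` transform used in the phase-2 loop above behaves like the plain MSE near zero (its Taylor expansion is ``x + x**2 + ...``) while penalizing badly fit batches much more heavily; a quick numerical check, illustrative only:

import torch
x = torch.tensor([0.001, 0.01, 0.1, 1.0])    # per-atom MSE values
print(0.5 * (torch.exp(2 * x) - 1))          # ~ [0.0010, 0.0101, 0.1107, 3.1945]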
# -*- coding: utf-8 -*-
"""
.. _training-example-ignite:
Train Your Own Neural Network Potential, Using PyTorch-Ignite
=============================================================
We have seen how to train a neural network potential by manually writing the
training loop in :ref:`training-example`. TorchANI provides tools that work
with PyTorch-Ignite to simplify writing training code. This tutorial
shows how to use these tools to train a demo model.
This tutorial assumes readers have read :ref:`training-example`.
"""
###############################################################################
# To begin with, let's first import the modules we will use:
import torch
import ignite
import torchani
import timeit
import os
import ignite.contrib.handlers
import torch.utils.tensorboard
###############################################################################
# Now let's set up the training hyperparameters and dataset.
# training and validation set
try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') # noqa: E501
# checkpoint file to save model when validation RMSE improves
model_checkpoint = 'model.pt'
# max epochs to run the training
max_epochs = 20
# Compute the training RMSE every this many epochs. Since the training set is
# usually huge and the loss function does not directly give us the RMSE, we need
# to check the training RMSE to watch for overfitting.
training_rmse_every = 5
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# batch size
batch_size = 1024
# log directory for tensorboard
log = 'runs'
###############################################################################
# Instead of manually specifying hyperparameters as in :ref:`training-example`,
# here we will load them from files.
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
sae_file = os.path.join(path, '../torchani/resources/ani-1x_8x/sae_linfit.dat') # noqa: E501
consts = torchani.neurochem.Constants(const_file)
aev_computer = torchani.AEVComputer(**consts)
energy_shifter = torchani.neurochem.load_sae(sae_file)
###############################################################################
# Now let's define atomic neural networks. Here in this demo, we use the same
# size of neural network for all atom types, but this is not necessary.
def atomic():
model = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 64),
torch.nn.CELU(0.1),
torch.nn.Linear(64, 1)
)
return model
nn = torchani.ANIModel([atomic() for _ in range(4)])
print(nn)
###############################################################################
# If a checkpoint from previous training exists, then load it.
if os.path.isfile(model_checkpoint):
nn.load_state_dict(torch.load(model_checkpoint))
else:
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Now set up tensorboard
writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
###############################################################################
# Now load training and validation datasets into memory.
training = torchani.data.BatchedANIDataset(
training_path, consts.species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
validation_path, consts.species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset])
###############################################################################
# We have tools to deal with the chunking (see :ref:`training-example`). These
# tools can be used as follows:
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
###############################################################################
# Let's add a progress bar for the trainer
pbar = ignite.contrib.handlers.ProgressBar()
pbar.attach(trainer)
###############################################################################
# And some event handlers to compute validation and training metrics:
def hartree2kcal(x):
return 627.509 * x
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
def evaluate(dataset, name):
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
}
)
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar(name, rmse, trainer.state.epoch)
# compute validation RMSE
evaluate(validation, 'validation_rmse_vs_epoch')
# compute training RMSE
if trainer.state.epoch % training_rmse_every == 1:
evaluate(training, 'training_rmse_vs_epoch')
# checkpoint model
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# Also add a handler to log the elapsed time:
start = timeit.default_timer()
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_time(trainer):
elapsed = round(timeit.default_timer() - start, 2)
writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)
###############################################################################
# Also log the loss per iteration:
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
iteration = trainer.state.iteration
writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)
###############################################################################
# And finally, we are ready to run:
trainer.run(training, max_epochs)
...@@ -14,17 +14,18 @@ setup_attrs = { ...@@ -14,17 +14,18 @@ setup_attrs = {
'setup_requires': ['setuptools_scm'], 'setup_requires': ['setuptools_scm'],
'install_requires': [ 'install_requires': [
'torch-nightly', 'torch-nightly',
'pytorch-ignite-nightly',
'lark-parser', 'lark-parser',
'h5py',
], ],
'test_suite': 'nose.collector', 'test_suite': 'nose.collector',
'tests_require': [ 'tests_require': [
'nose', 'nose',
'tensorboardX', 'tb-nightly',
'tqdm', 'tqdm',
'ase', 'ase',
'coverage', 'coverage',
'h5py',
'pytorch-ignite-nightly',
'pillow',
], ],
} }
......
...@@ -30,11 +30,9 @@ from .aev import AEVComputer ...@@ -30,11 +30,9 @@ from .aev import AEVComputer
from . import utils from . import utils
from . import neurochem from . import neurochem
from . import models from . import models
from . import optim
from pkg_resources import get_distribution, DistributionNotFound from pkg_resources import get_distribution, DistributionNotFound
import sys import sys
if sys.version_info[0] > 2:
from . import ignite
from . import data
try: try:
__version__ = get_distribution(__name__).version __version__ = get_distribution(__name__).version
...@@ -43,10 +41,20 @@ except DistributionNotFound: ...@@ -43,10 +41,20 @@ except DistributionNotFound:
pass pass
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
'ignite', 'utils', 'neurochem', 'data', 'models'] 'utils', 'neurochem', 'models', 'optim']
try: try:
from . import ase # noqa: F401 from . import ase # noqa: F401
__all__.append('ase') __all__.append('ase')
except ImportError: except ImportError:
pass pass
if sys.version_info[0] > 2:
try:
from . import ignite # noqa: F401
__all__.append('ignite')
from . import data # noqa: F401
__all__.append('data')
except ImportError:
pass
...@@ -111,10 +111,10 @@ def RMSEMetric(key): ...@@ -111,10 +111,10 @@ def RMSEMetric(key):
return DictMetric(key, RootMeanSquaredError()) return DictMetric(key, RootMeanSquaredError())
def MAEMetric(key): def MaxAEMetric(key):
"""Create max absolute error metric on key.""" """Create max absolute error metric on key."""
return DictMetric(key, MaximumAbsoluteError()) return DictMetric(key, MaximumAbsoluteError())
__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric', __all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
'MAEMetric'] 'MaxAEMetric']
...@@ -8,7 +8,6 @@ import bz2 ...@@ -8,7 +8,6 @@ import bz2
import lark import lark
import struct import struct
import itertools import itertools
import ignite
import math import math
import timeit import timeit
from . import _six # noqa:F401 from . import _six # noqa:F401
...@@ -17,7 +16,7 @@ import sys ...@@ -17,7 +16,7 @@ import sys
from ..nn import ANIModel, Ensemble, Gaussian from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter, ChemicalSymbolsToInts from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer from ..aev import AEVComputer
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric from ..optim import AdamW
class Constants(collections.abc.Mapping): class Constants(collections.abc.Mapping):
...@@ -380,8 +379,6 @@ def hartree2kcal(x): ...@@ -380,8 +379,6 @@ def hartree2kcal(x):
if sys.version_info[0] > 2: if sys.version_info[0] > 2:
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
class Trainer: class Trainer:
"""Train with NeuroChem training configurations. """Train with NeuroChem training configurations.
...@@ -391,7 +388,7 @@ if sys.version_info[0] > 2: ...@@ -391,7 +388,7 @@ if sys.version_info[0] > 2:
device (:class:`torch.device`): device to train the model device (:class:`torch.device`): device to train the model
tqdm (bool): whether to enable tqdm tqdm (bool): whether to enable tqdm
tensorboard (str): Directory to store tensorboard log file, set to tensorboard (str): Directory to store tensorboard log file, set to
``None`` to disable tensorboardX. ``None`` to disable tensorboard.
aev_caching (bool): Whether to use AEV caching. aev_caching (bool): Whether to use AEV caching.
checkpoint_name (str): Name of the checkpoint file, checkpoints checkpoint_name (str): Name of the checkpoint file, checkpoints
will be stored in the network directory with this file name. will be stored in the network directory with this file name.
...@@ -400,6 +397,30 @@ if sys.version_info[0] > 2: ...@@ -400,6 +397,30 @@ if sys.version_info[0] > 2:
def __init__(self, filename, device=torch.device('cuda'), tqdm=False, def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
tensorboard=None, aev_caching=False, tensorboard=None, aev_caching=False,
checkpoint_name='model.pt'): checkpoint_name='model.pt'):
try:
import ignite
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MaxAEMetric
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
except ImportError:
raise RuntimeError(
'NeuroChem Trainer requires ignite, '
'please install pytorch-ignite-nightly from PYPI')
self.ignite = ignite
class dummy:
pass
self.imports = dummy()
self.imports.Container = Container
self.imports.MSELoss = MSELoss
self.imports.TransformedLoss = TransformedLoss
self.imports.RMSEMetric = RMSEMetric
self.imports.MaxAEMetric = MaxAEMetric
self.imports.BatchedANIDataset = BatchedANIDataset
self.imports.AEVCacheLoader = AEVCacheLoader
self.filename = filename self.filename = filename
self.device = device self.device = device
self.aev_caching = aev_caching self.aev_caching = aev_caching
...@@ -411,8 +432,8 @@ if sys.version_info[0] > 2: ...@@ -411,8 +432,8 @@ if sys.version_info[0] > 2:
else: else:
self.tqdm = None self.tqdm = None
if tensorboard is not None: if tensorboard is not None:
import tensorboardX import torch.utils.tensorboard
self.tensorboard = tensorboardX.SummaryWriter( self.tensorboard = torch.utils.tensorboard.SummaryWriter(
log_dir=tensorboard) log_dir=tensorboard)
self.training_eval_every = 20 self.training_eval_every = 20
else: else:
...@@ -612,9 +633,12 @@ if sys.version_info[0] > 2: ...@@ -612,9 +633,12 @@ if sys.version_info[0] > 2:
if 'l2norm' in layer: if 'l2norm' in layer:
if layer['l2norm'] == 1: if layer['l2norm'] == 1:
self.parameters.append({ self.parameters.append({
'params': module.parameters(), 'params': [module.weight],
'weight_decay': layer['l2valu'], 'weight_decay': layer['l2valu'],
}) })
self.parameters.append({
'params': [module.bias],
})
else: else:
self.parameters.append({ self.parameters.append({
'params': module.parameters(), 'params': module.parameters(),
...@@ -636,12 +660,12 @@ if sys.version_info[0] > 2: ...@@ -636,12 +660,12 @@ if sys.version_info[0] > 2:
self.nnp = self.model self.nnp = self.model
else: else:
self.nnp = torch.nn.Sequential(self.aev_computer, self.model) self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
self.container = Container({'energies': self.nnp}).to(self.device) self.container = self.imports.Container({'energies': self.nnp}).to(self.device)
# losses # losses
self.mse_loss = MSELoss('energies') self.mse_loss = self.imports.MSELoss('energies')
self.exp_loss = TransformedLoss( self.exp_loss = self.imports.TransformedLoss(
MSELoss('energies'), self.imports.MSELoss('energies'),
lambda x: 0.5 * (torch.exp(2 * x) - 1)) lambda x: 0.5 * (torch.exp(2 * x) - 1))
if params: if params:
...@@ -652,17 +676,17 @@ if sys.version_info[0] > 2: ...@@ -652,17 +676,17 @@ if sys.version_info[0] > 2:
self.best_validation_rmse = math.inf self.best_validation_rmse = math.inf
def evaluate(self, dataset): def evaluate(self, dataset):
"""Evaluate on given dataset to compute RMSE and MAE.""" """Evaluate on given dataset to compute RMSE and MaxAE."""
evaluator = ignite.engine.create_supervised_evaluator( evaluator = self.ignite.engine.create_supervised_evaluator(
self.container, self.container,
metrics={ metrics={
'RMSE': RMSEMetric('energies'), 'RMSE': self.imports.RMSEMetric('energies'),
'MAE': MAEMetric('energies'), 'MaxAE': self.imports.MaxAEMetric('energies'),
} }
) )
evaluator.run(dataset) evaluator.run(dataset)
metrics = evaluator.state.metrics metrics = evaluator.state.metrics
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']) return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MaxAE'])
def load_data(self, training_path, validation_path): def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file. """Load training and validation dataset from file.
...@@ -671,14 +695,14 @@ if sys.version_info[0] > 2: ...@@ -671,14 +695,14 @@ if sys.version_info[0] > 2:
directory, otherwise it should be path to the dataset. directory, otherwise it should be path to the dataset.
""" """
if self.aev_caching: if self.aev_caching:
self.training_set = AEVCacheLoader(training_path) self.training_set = self.imports.AEVCacheLoader(training_path)
self.validation_set = AEVCacheLoader(validation_path) self.validation_set = self.imports.AEVCacheLoader(validation_path)
else: else:
self.training_set = BatchedANIDataset( self.training_set = self.imports.BatchedANIDataset(
training_path, self.consts.species_to_tensor, training_path, self.consts.species_to_tensor,
self.training_batch_size, device=self.device, self.training_batch_size, device=self.device,
transform=[self.shift_energy.subtract_from_dataset]) transform=[self.shift_energy.subtract_from_dataset])
self.validation_set = BatchedANIDataset( self.validation_set = self.imports.BatchedANIDataset(
validation_path, self.consts.species_to_tensor, validation_path, self.consts.species_to_tensor,
self.validation_batch_size, device=self.device, self.validation_batch_size, device=self.device,
transform=[self.shift_energy.subtract_from_dataset]) transform=[self.shift_energy.subtract_from_dataset])
...@@ -689,38 +713,38 @@ if sys.version_info[0] > 2: ...@@ -689,38 +713,38 @@ if sys.version_info[0] > 2:
def decorate(trainer): def decorate(trainer):
@trainer.on(ignite.engine.Events.STARTED) @trainer.on(self.ignite.engine.Events.STARTED)
def initialize(trainer): def initialize(trainer):
trainer.state.no_improve_count = 0 trainer.state.no_improve_count = 0
trainer.state.epoch += self.global_epoch trainer.state.epoch += self.global_epoch
trainer.state.iteration += self.global_iteration trainer.state.iteration += self.global_iteration
@trainer.on(ignite.engine.Events.COMPLETED) @trainer.on(self.ignite.engine.Events.COMPLETED)
def finalize(trainer): def finalize(trainer):
self.global_epoch = trainer.state.epoch self.global_epoch = trainer.state.epoch
self.global_iteration = trainer.state.iteration self.global_iteration = trainer.state.iteration
if self.nmax > 0: if self.nmax > 0:
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED) @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
def terminate_when_nmax_reaches(trainer): def terminate_when_nmax_reaches(trainer):
if trainer.state.epoch >= self.nmax: if trainer.state.epoch >= self.nmax:
trainer.terminate() trainer.terminate()
if self.tqdm is not None: if self.tqdm is not None:
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer): def init_tqdm(trainer):
trainer.state.tqdm = self.tqdm( trainer.state.tqdm = self.tqdm(
total=len(self.training_set), desc='epoch') total=len(self.training_set), desc='epoch')
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED) @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer): def update_tqdm(trainer):
trainer.state.tqdm.update(1) trainer.state.tqdm.update(1)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED) @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer): def finalize_tqdm(trainer):
trainer.state.tqdm.close() trainer.state.tqdm.close()
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer): def validation_and_checkpoint(trainer):
trainer.state.rmse, trainer.state.mae = \ trainer.state.rmse, trainer.state.mae = \
self.evaluate(self.validation_set) self.evaluate(self.validation_set)
...@@ -736,7 +760,7 @@ if sys.version_info[0] > 2: ...@@ -736,7 +760,7 @@ if sys.version_info[0] > 2:
trainer.terminate() trainer.terminate()
if self.tensorboard is not None: if self.tensorboard is not None:
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def log_per_epoch(trainer): def log_per_epoch(trainer):
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
epoch = trainer.state.epoch epoch = trainer.state.epoch
...@@ -764,7 +788,7 @@ if sys.version_info[0] > 2: ...@@ -764,7 +788,7 @@ if sys.version_info[0] > 2:
self.tensorboard.add_scalar( self.tensorboard.add_scalar(
'training_mae_vs_epoch', training_mae, epoch) 'training_mae_vs_epoch', training_mae, epoch)
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED) @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer): def log_loss(trainer):
iteration = trainer.state.iteration iteration = trainer.state.iteration
loss = trainer.state.output loss = trainer.state.output
...@@ -775,12 +799,12 @@ if sys.version_info[0] > 2: ...@@ -775,12 +799,12 @@ if sys.version_info[0] > 2:
# training using mse loss first until the validation MAE decrease # training using mse loss first until the validation MAE decrease
# to < 1 Hartree # to < 1 Hartree
optimizer = torch.optim.Adam(self.parameters, lr=lr) optimizer = AdamW(self.parameters, lr=lr)
trainer = ignite.engine.create_supervised_trainer( trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.mse_loss) self.container, optimizer, self.mse_loss)
decorate(trainer) decorate(trainer)
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
def terminate_if_smaller_enough(trainer): def terminate_if_smaller_enough(trainer):
if trainer.state.mae < 1.0: if trainer.state.mae < 1.0:
trainer.terminate() trainer.terminate()
...@@ -788,8 +812,8 @@ if sys.version_info[0] > 2: ...@@ -788,8 +812,8 @@ if sys.version_info[0] > 2:
trainer.run(self.training_set, max_epochs=math.inf) trainer.run(self.training_set, max_epochs=math.inf)
while lr > self.min_lr: while lr > self.min_lr:
optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) optimizer = AdamW(self.model.parameters(), lr=lr)
trainer = ignite.engine.create_supervised_trainer( trainer = self.ignite.engine.create_supervised_trainer(
self.container, optimizer, self.exp_loss) self.container, optimizer, self.exp_loss)
decorate(trainer) decorate(trainer)
trainer.run(self.training_set, max_epochs=math.inf) trainer.run(self.training_set, max_epochs=math.inf)
......
"""AdamW implementation"""
import math
import torch
from torch.optim.optimizer import Optimizer
# Copied and modified from: https://github.com/pytorch/pytorch/pull/4429
class AdamW(Optimizer):
r"""Implements AdamW algorithm.
It has been proposed in `Decoupled Weight Decay Regularization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay factor (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data)
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss
import torch import torch
import torch.utils.data
import math import math
......