Commit 123e4760 authored by Gao, Xiang, committed by Farhad Ramezanghorbani

Removing ignite (#354)

* Remove ignite

* remove neurochem-test

* remove cache-aev
parent 566320ec
@@ -22,7 +22,7 @@ jobs:
     - name: Install dependencies
       run: |
         ci/install_dependencies.sh
-        pip install h5py pytorch-ignite tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery
+        pip install h5py tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery
         pip install .
     - name: Download data files
       run: ./download.sh
...
@@ -20,7 +20,7 @@ jobs:
     - name: Install dependencies
       run: |
         ci/install_dependencies.sh
-        pip install h5py pytorch-ignite tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery
+        pip install h5py tb-nightly sphinx sphinx_rtd_theme matplotlib pillow sphinx-gallery
         pip install .
     - name: Download data files
       run: ./download.sh
...
@@ -31,9 +31,5 @@ jobs:
       run: pip install h5py
     - name: COMP6 Benchmark
       run: python tools/comp6.py dataset/COMP6/COMP6v1/s66x8
-    - name: Install more dependencies
-      run: pip install pytorch-ignite
     - name: Training Benchmark
       run: python tools/training-benchmark.py dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
-    - name: NeuroChem Test
-      run: python tools/neurochem-test.py dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
@@ -12,7 +12,7 @@ jobs:
         python-version: [3.6, 3.7]
         test-filenames: [
           test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
-          test_data_new.py, test_ignite.py, test_utils.py, test_ase.py, test_energies.py,
+          test_data_new.py, test_utils.py, test_ase.py, test_energies.py,
           test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
           test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
...
@@ -77,18 +77,3 @@ TorchANI Optimizer
 .. automodule:: torchani.optim
 .. autoclass:: torchani.optim.AdamW
-
-Ignite Helpers
-==============
-
-.. automodule:: torchani.ignite
-.. autoclass:: torchani.ignite.Container
-    :members:
-.. autoclass:: torchani.ignite.DictLoss
-.. autoclass:: torchani.ignite.PerAtomDictLoss
-.. autoclass:: torchani.ignite.TransformedLoss
-.. autofunction:: torchani.ignite.MSELoss
-.. autoclass:: torchani.ignite.DictMetric
-.. autofunction:: torchani.ignite.RMSEMetric
-.. autofunction:: torchani.ignite.MaxAEMetric
@@ -22,7 +22,6 @@ Welcome to TorchANI's documentation!
    examples/nnp_training
    examples/nnp_training_force
    examples/nnp_training_ignite
-   examples/cache_aev
    examples/neurochem_trainer

 .. toctree::
...
examples/cache_aev.py (deleted file):

# -*- coding: utf-8 -*-
"""
Use Disk Cache of AEV to Boost Training
=======================================
In the previous :ref:`training-example` example, AEVs are computed every time
they are needed. This is not very efficient, because the AEVs never change
during training. If you have a fast SSD, it can be beneficial to cache these
AEVs. This example shows how to use a disk cache to speed up training.
"""
###############################################################################
# Most of the code in this example is a line-by-line copy of
# :ref:`training-example`.
import torch
import ignite
import torchani
import timeit
import os
import ignite.contrib.handlers
import torch.utils.tensorboard
# training and validation set
try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
training_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
validation_path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5') # noqa: E501
# checkpoint file to save model when validation RMSE improves
model_checkpoint = 'model.pt'
# max epochs to run the training
max_epochs = 20
# Compute training RMSE every this many epochs. Since the training set is
# usually huge and the loss function does not directly give us the RMSE, we
# periodically check the training RMSE to detect overfitting.
training_rmse_every = 5
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# batch size
batch_size = 1024
# log directory for tensorboard
log = 'runs'
###############################################################################
# Here there is no need to manually construct an AEV computer or an energy
# shifter, but we do need to generate a disk cache for the datasets.
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params')
sae_file = os.path.join(path, '../torchani/resources/ani-1x_8x/sae_linfit.dat')
training_cache = './training_cache'
validation_cache = './validation_cache'
# If the cache directories already exist, we assume the data have already
# been cached and skip the generation step.
if not os.path.exists(training_cache):
torchani.data.cache_aev(training_cache, training_path, batch_size, device,
const_file, True, sae_file)
if not os.path.exists(validation_cache):
torchani.data.cache_aev(validation_cache, validation_path, batch_size,
device, const_file, True, sae_file)
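# (assumption: in this version of torchani.data.cache_aev, the positional
# arguments after `device` are the constants file, a flag that enables
# subtracting self atomic energies, and the SAE file)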
###############################################################################
# The code that defines the network is also the same
def atomic():
model = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 64),
torch.nn.CELU(0.1),
torch.nn.Linear(64, 1)
)
return model
nn = torchani.ANIModel([atomic() for _ in range(4)])
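# one atomic network per species; the four networks correspond to H, C, N
# and O, the species of the rHCNO constants file used above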
print(nn)
if os.path.isfile(model_checkpoint):
nn.load_state_dict(torch.load(model_checkpoint))
else:
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# Except that here we do not include the AEV computer in our pipeline,
# because the cache loader loads precomputed AEVs from disk.
model = nn.to(device)
###############################################################################
# This part is also a line-by-line copy
writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
###############################################################################
# Here we don't need to construct a :class:`torchani.data.BatchedANIDataset`
# object; instead we use a :class:`torchani.data.AEVCacheLoader`.
training = torchani.data.AEVCacheLoader(training_cache)
validation = torchani.data.AEVCacheLoader(validation_cache)
###############################################################################
# The rest of the code is again the same
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
###############################################################################
# Let's add a progress bar for the trainer
pbar = ignite.contrib.handlers.ProgressBar()
pbar.attach(trainer)
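# 1 hartree is approximately 627.509 kcal/mol; the helper below converts
# hartrees to kcal/mol for reporting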
def hartree2kcal(x):
return 627.509 * x
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
def evaluate(dataset, name):
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
}
)
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar(name, rmse, trainer.state.epoch)
# compute validation RMSE
evaluate(validation, 'validation_rmse_vs_epoch')
# compute training RMSE
if trainer.state.epoch % training_rmse_every == 1:
evaluate(training, 'training_rmse_vs_epoch')
# checkpoint model
torch.save(nn.state_dict(), model_checkpoint)
start = timeit.default_timer()
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_time(trainer):
elapsed = round(timeit.default_timer() - start, 2)
writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
iteration = trainer.state.iteration
writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)
trainer.run(training, max_epochs)
@@ -40,8 +40,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 # .. note::
 #
 #   Besides defining these hyperparameters programmatically,
-#   :mod:`torchani.neurochem` provide tools to read them from file. See also
-#   :ref:`training-example-ignite` for an example of usage.
+#   :mod:`torchani.neurochem` provides tools to read them from file.
 #
 # .. _rHCNO-5.2R_16-3.5A_a4-8.params:
 #   https://github.com/aiqm/torchani/blob/master/torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params
...
examples/nnp_training_ignite.py (deleted file):

# -*- coding: utf-8 -*-
"""
.. _training-example-ignite:
Train Your Own Neural Network Potential, Using PyTorch-Ignite
=============================================================
We have seen how to train a neural network potential by manually writing a
training loop in :ref:`training-example`. TorchANI provides tools that work
with PyTorch-Ignite to simplify writing training code. This tutorial shows how
to use these tools to train a demo model. The setup in this demo is not
necessarily identical to NeuroChem's.

This tutorial assumes readers have read :ref:`training-example`.
"""
###############################################################################
# To begin with, let's first import the modules and set up the device we will use:
import torch
import ignite
import torchani
import timeit
import os
import ignite.contrib.handlers
import torch.utils.tensorboard
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###############################################################################
# Now let's setup training hyperparameters and dataset.
# training and validation set
try:
path = os.path.dirname(os.path.realpath(__file__))
except NameError:
path = os.getcwd()
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
# checkpoint file to save model when validation RMSE improves
model_checkpoint = 'model.pt'
# max epochs to run the training
max_epochs = 20
# Compute training RMSE every this many epochs. Since the training set is
# usually huge and the loss function does not directly give us the RMSE, we
# periodically check the training RMSE to detect overfitting.
training_rmse_every = 5
# batch size
batch_size = 2560
# log directory for tensorboard
log = 'runs'
###############################################################################
# Instead of manually specifying hyperparameters as in :ref:`training-example`,
# here we will load them from files.
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
consts = torchani.neurochem.Constants(const_file)
aev_computer = torchani.AEVComputer(**consts)
energy_shifter = torchani.utils.EnergyShifter(None)
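# (passing None here presumably makes the self atomic energies be fitted from
# the dataset by the subtract_from_dataset transform used below)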
###############################################################################
# Now let's define atomic neural networks. In this demo we use the same
# network size for all atom types, but this is not required.
def atomic():
model = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 128),
torch.nn.CELU(0.1),
torch.nn.Linear(128, 64),
torch.nn.CELU(0.1),
torch.nn.Linear(64, 1)
)
return model
nn = torchani.ANIModel([atomic() for _ in range(4)])
print(nn)
###############################################################################
# If a checkpoint from previous training exists, load it.
if os.path.isfile(model_checkpoint):
nn.load_state_dict(torch.load(model_checkpoint))
else:
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torchani.nn.Sequential(aev_computer, nn).to(device)
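# torchani.nn.Sequential is used here instead of torch.nn.Sequential because
# the modules in this pipeline pass (species, tensor) tuples between each other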
###############################################################################
# Now set up tensorboard
writer = torch.utils.tensorboard.SummaryWriter(log_dir=log)
###############################################################################
# Now load training and validation datasets into memory.
training, validation = torchani.data.load_ani_dataset(
dspath, consts.species_to_tensor, batch_size, rm_outlier=True, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
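# split=[0.8, None]: 80% of the batches go to the training set and the
# remainder (None) to the validation set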
###############################################################################
# We have tools to deal with the chunking (see :ref:`training-example`). These
# tools can be used as follows:
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
###############################################################################
# Let's add a progress bar for the trainer
pbar = ignite.contrib.handlers.ProgressBar()
pbar.attach(trainer)
###############################################################################
# And some event handlers to compute validation and training metrics:
def hartree2kcal(x):
return 627.509 * x
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def validation_and_checkpoint(trainer):
def evaluate(dataset, name):
evaluator = ignite.engine.create_supervised_evaluator(
container,
metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
}
)
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
writer.add_scalar(name, rmse, trainer.state.epoch)
# compute validation RMSE
evaluate(validation, 'validation_rmse_vs_epoch')
# compute training RMSE
if trainer.state.epoch % training_rmse_every == 1:
evaluate(training, 'training_rmse_vs_epoch')
# checkpoint model
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# And another handler to log the elapsed time:
start = timeit.default_timer()
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def log_time(trainer):
elapsed = round(timeit.default_timer() - start, 2)
writer.add_scalar('time_vs_epoch', elapsed, trainer.state.epoch)
###############################################################################
# Also log the loss per iteration:
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def log_loss(trainer):
iteration = trainer.state.iteration
writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)
###############################################################################
# And finally, we are ready to run:
trainer.run(training, max_epochs)
test_ignite.py (deleted file):

import os
import unittest
import torch
import copy
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator, Events
import torchani
import torchani.ignite
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batchsize = 4
threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
ani1x = torchani.models.ANI1x()
aev_computer = ani1x.aev_computer
nnp = copy.deepcopy(ani1x.neural_networks[0])
shift_energy = ani1x.energy_shifter
ds = torchani.data.load_ani_dataset(
path, ani1x.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset],
device=aev_computer.EtaR.device)
ds = torch.utils.data.Subset(ds, [0])
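        # keep only a single batch so the model can overfit it, driving both
        # the loss and the RMSE below the threshold within max_epochs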
class Flatten(torch.nn.Module):
def forward(self, x):
return x[0], x[1].flatten()
model = torchani.nn.Sequential(aev_computer, nnp, Flatten())
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters())
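        # exp(x) - 1 is approximately x for small x, so the transformed loss
        # below behaves like plain MSE near convergence while penalizing
        # large errors more strongly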
loss = torchani.ignite.TransformedLoss(
torchani.ignite.MSELoss('energies'),
lambda x: torch.exp(x) - 1)
trainer = create_supervised_trainer(
container, optimizer, loss)
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
@trainer.on(Events.COMPLETED)
def completes(trainer):
evaluator.run(ds)
metrics = evaluator.state.metrics
self.assertLess(metrics['RMSE'], threshold)
self.assertLess(trainer.state.output, threshold)
trainer.run(ds, max_epochs=1000)
if __name__ == '__main__':
unittest.main()
tools/neurochem-test.py (deleted file):

import os
import torch
import torchani
import ignite
import pickle
import argparse
ani1x = torchani.models.ANI1x()
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset. The path can be a hdf5 file or \
a directory containing hdf5 files. It can also be a file \
dumped by pickle.')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
parser.add_argument('--const_file',
help='File storing constants',
default=ani1x.const_file)
parser.add_argument('--sae_file',
help='File storing self atomic energies',
default=ani1x.sae_file)
parser.add_argument('--network_dir',
help='Directory or prefix of directories storing networks',
default=ani1x.ensemble_prefix + '0/networks')
parser.add_argument('--compare_with',
help='The TorchANI model to compare with', default=None)
parser = parser.parse_args()
# load modules and datasets
device = torch.device(parser.device)
consts = torchani.neurochem.Constants(parser.const_file)
shift_energy = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
model = torch.nn.Sequential(aev_computer, nn)
container = torchani.ignite.Container({'energies': model})
container = container.to(device)
# load datasets
if parser.dataset_path.endswith('.h5') or \
parser.dataset_path.endswith('.hdf5') or \
os.path.isdir(parser.dataset_path):
dataset = torchani.data.load_ani_dataset(
parser.dataset_path, consts.species_to_tensor, parser.batch_size,
device=device, transform=[shift_energy.subtract_from_dataset])
datasets = [dataset]
else:
with open(parser.dataset_path, 'rb') as f:
datasets = pickle.load(f)
if not isinstance(datasets, list) and not isinstance(datasets, tuple):
datasets = [datasets]
# prepare evaluator
def hartree2kcal(x):
return 627.509 * x
def evaluate(dataset, container):
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
evaluator.run(dataset)
metrics = evaluator.state.metrics
rmse = hartree2kcal(metrics['RMSE'])
print(rmse, 'kcal/mol')
for dataset in datasets:
evaluate(dataset, container)
if parser.compare_with is not None:
nn.load_state_dict(torch.load(parser.compare_with))
print('TorchANI results:')
for dataset in datasets:
evaluate(dataset, container)
@@ -5,8 +5,7 @@ the `Roitberg group`_. TorchANI contains classes like
 be pipelined to compute molecular energies from the 3D coordinates of
 molecules. It also includes tools to: deal with ANI datasets (e.g. `ANI-1`_,
 `ANI-1x`_, `ANI-1ccx`_, etc.) at :attr:`torchani.data`, import various file
-formats of NeuroChem at :attr:`torchani.neurochem`, help working with ignite
-at :attr:`torchani.ignite`, and more at :attr:`torchani.utils`.
+formats of NeuroChem at :attr:`torchani.neurochem`, and more at :attr:`torchani.utils`.

 .. _ANI:
   http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
@@ -51,12 +50,6 @@ except ImportError:
 if sys.version_info[0] > 2:
-    try:
-        from . import ignite  # noqa: F401
-        __all__.append('ignite')
-    except ImportError:
-        pass
-
     try:
         from . import data  # noqa: F401
         __all__.append('data')
...
torchani/ignite.py (deleted file):

# -*- coding: utf-8 -*-
"""Helpers for working with ignite."""
from __future__ import absolute_import
import torch
from . import utils
from torch.nn.modules.loss import _Loss
from ignite.metrics import Metric, RootMeanSquaredError, MeanAbsoluteError
from ignite.contrib.metrics.regression import MaximumAbsoluteError
class Container(torch.nn.ModuleDict):
r"""Each minibatch is splitted into chunks, as explained in the docstring of
:method:`torchani.data.load_ani_dataset`, as a result, it is impossible to
use :class:`torchani.AEVComputer`, :class:`torchani.ANIModel` directly with
ignite. This class is designed to solve this issue.
Arguments:
modules (:class:`collections.abc.Mapping`): same as the argument in
:class:`torch.nn.ModuleDict`.
"""
def __init__(self, modules):
super(Container, self).__init__(modules)
def forward(self, species_x):
"""Takes sequence of species, coordinates pair as input, and returns
computed properties as a dictionary. Same property from different
chunks will be concatenated to form a single tensor for a batch.
"""
results = {k: [] for k in self}
for sx in species_x:
for k in self:
_, result = self[k](tuple(sx))
results[k].append(result)
for k in self:
results[k] = torch.cat(results[k])
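        # species tensors from different chunks may be padded to different
        # lengths; utils.pad re-pads them to a common number of atoms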
results['species'] = utils.pad([s for s, _ in species_x])
return results
class DictLoss(_Loss):
"""Since :class:`Container` output dictionaries, losses defined in
:attr:`torch.nn` needs to be wrapped before used. This class wraps losses
that directly work on tensors with a key by calling the wrapped loss on the
associated value of that key.
"""
def __init__(self, key, loss):
super(DictLoss, self).__init__()
self.key = key
self.loss = loss
def forward(self, input_, other):
return self.loss(input_[self.key], other[self.key])
class PerAtomDictLoss(DictLoss):
"""Similar to :class:`DictLoss`, but scale the loss values by the number of
atoms for each structure. The `loss` argument must be set to not to reduce
by the caller. Currently the only reduce operation supported is averaging.
"""
def forward(self, input_, other):
loss = self.loss(input_[self.key], other[self.key])
num_atoms = (input_['species'] >= 0).to(loss.dtype).to(loss.device).sum(dim=1)
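        # padded atoms are encoded with species -1, so counting entries with
        # species >= 0 gives the number of real atoms in each structure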
loss /= num_atoms
n = loss.numel()
return loss.sum() / n
class DictMetric(Metric):
"""Similar to :class:`DictLoss`, but this is for metric, not loss."""
def __init__(self, key, metric):
self.key = key
self.metric = metric
super(DictMetric, self).__init__()
def reset(self):
self.metric.reset()
def update(self, output):
y_pred, y = output
self.metric.update((y_pred[self.key], y[self.key]))
def compute(self):
return self.metric.compute()
def MSELoss(key, per_atom=True):
"""Create MSE loss on the specified key."""
if per_atom:
return PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
return DictLoss(key, torch.nn.MSELoss())
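# For example, MSELoss('energies') gives a per-atom MSE by default, which
# presumably keeps the loss scale comparable across molecules of different sizes.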
class TransformedLoss(_Loss):
"""Do a transformation on loss values."""
def __init__(self, origin, transform):
super(TransformedLoss, self).__init__()
self.origin = origin
self.transform = transform
def forward(self, input_, other):
return self.transform(self.origin(input_, other))
def RMSEMetric(key):
"""Create RMSE metric on key."""
return DictMetric(key, RootMeanSquaredError())
def MaxAEMetric(key):
"""Create max absolute error metric on key."""
return DictMetric(key, MaximumAbsoluteError())
def MAEMetric(key):
"""Create max absolute error metric on key."""
return DictMetric(key, MeanAbsoluteError())
__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
           'MaxAEMetric', 'MAEMetric']