Unverified Commit 4eb539e7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove MaxAE from TorchANI (#150)

parent 2cc64cb7
...@@ -6,6 +6,7 @@ from . import utils ...@@ -6,6 +6,7 @@ from . import utils
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ignite.metrics.metric import Metric from ignite.metrics.metric import Metric
from ignite.metrics import RootMeanSquaredError from ignite.metrics import RootMeanSquaredError
from ignite.contrib.metrics.regression import MaximumAbsoluteError
class Container(torch.nn.ModuleDict): class Container(torch.nn.ModuleDict):
...@@ -111,29 +112,9 @@ def RMSEMetric(key): ...@@ -111,29 +112,9 @@ def RMSEMetric(key):
return DictMetric(key, RootMeanSquaredError()) return DictMetric(key, RootMeanSquaredError())
class MaxAbsoluteError(Metric):
"""
Calculates the max absolute error.
- `update` must receive output of the form `(y_pred, y)`.
"""
def reset(self):
self._max_of_absolute_errors = 0.0
def update(self, output):
y_pred, y = output
absolute_errors = torch.abs(y_pred - y.view_as(y_pred))
batch_max = absolute_errors.max().item()
if batch_max > self._max_of_absolute_errors:
self._max_of_absolute_errors = batch_max
def compute(self):
return self._max_of_absolute_errors
def MAEMetric(key): def MAEMetric(key):
"""Create max absolute error metric on key.""" """Create max absolute error metric on key."""
return DictMetric(key, MaxAbsoluteError()) return DictMetric(key, MaximumAbsoluteError())
__all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric', __all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment