Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
4eb539e7
Unverified
Commit
4eb539e7
authored
Dec 12, 2018
by
Gao, Xiang
Committed by
GitHub
Dec 12, 2018
Browse files
Remove MaxAE from TorchANI (#150)
parent
2cc64cb7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
21 deletions
+2
-21
torchani/ignite.py
torchani/ignite.py
+2
-21
No files found.
torchani/ignite.py
View file @
4eb539e7
...
@@ -6,6 +6,7 @@ from . import utils
...
@@ -6,6 +6,7 @@ from . import utils
from
torch.nn.modules.loss
import
_Loss
from
torch.nn.modules.loss
import
_Loss
from
ignite.metrics.metric
import
Metric
from
ignite.metrics.metric
import
Metric
from
ignite.metrics
import
RootMeanSquaredError
from
ignite.metrics
import
RootMeanSquaredError
from
ignite.contrib.metrics.regression
import
MaximumAbsoluteError
class
Container
(
torch
.
nn
.
ModuleDict
):
class
Container
(
torch
.
nn
.
ModuleDict
):
...
@@ -111,29 +112,9 @@ def RMSEMetric(key):
...
@@ -111,29 +112,9 @@ def RMSEMetric(key):
return
DictMetric
(
key
,
RootMeanSquaredError
())
return
DictMetric
(
key
,
RootMeanSquaredError
())
class
MaxAbsoluteError
(
Metric
):
"""
Calculates the max absolute error.
- `update` must receive output of the form `(y_pred, y)`.
"""
def
reset
(
self
):
self
.
_max_of_absolute_errors
=
0.0
def
update
(
self
,
output
):
y_pred
,
y
=
output
absolute_errors
=
torch
.
abs
(
y_pred
-
y
.
view_as
(
y_pred
))
batch_max
=
absolute_errors
.
max
().
item
()
if
batch_max
>
self
.
_max_of_absolute_errors
:
self
.
_max_of_absolute_errors
=
batch_max
def
compute
(
self
):
return
self
.
_max_of_absolute_errors
def
MAEMetric
(
key
):
def
MAEMetric
(
key
):
"""Create max absolute error metric on key."""
"""Create max absolute error metric on key."""
return
DictMetric
(
key
,
MaxAbsoluteError
())
return
DictMetric
(
key
,
Max
imum
AbsoluteError
())
__all__
=
[
'Container'
,
'MSELoss'
,
'TransformedLoss'
,
'RMSEMetric'
,
__all__
=
[
'Container'
,
'MSELoss'
,
'TransformedLoss'
,
'RMSEMetric'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment