Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
59b31d84
"vscode:/vscode.git/clone" did not exist on "ab0459f2b7685d54b8b4ea1578eeda6ddece0913"
Unverified
Commit
59b31d84
authored
Aug 03, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 03, 2018
Browse files
implement per atom loss (#51)
parent
84439caf
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
10 deletions
+22
-10
torchani/ignite/loss_metrics.py
torchani/ignite/loss_metrics.py
+22
-10
No files found.
torchani/ignite/loss_metrics.py
View file @
59b31d84
...
@@ -4,14 +4,6 @@ from ignite.metrics import RootMeanSquaredError
...
@@ -4,14 +4,6 @@ from ignite.metrics import RootMeanSquaredError
import
torch
import
torch
def
num_atoms
(
input
):
ret
=
[]
for
s
,
c
in
zip
(
input
[
'species'
],
input
[
'coordinates'
]):
ret
.
append
(
torch
.
full
((
c
.
shape
[
0
],),
len
(
s
),
dtype
=
c
.
dtype
,
device
=
c
.
device
))
return
torch
.
cat
(
ret
)
class
DictLoss
(
_Loss
):
class
DictLoss
(
_Loss
):
def
__init__
(
self
,
key
,
loss
):
def
__init__
(
self
,
key
,
loss
):
...
@@ -23,6 +15,23 @@ class DictLoss(_Loss):
...
@@ -23,6 +15,23 @@ class DictLoss(_Loss):
return
self
.
loss
(
input
[
self
.
key
],
other
[
self
.
key
])
return
self
.
loss
(
input
[
self
.
key
],
other
[
self
.
key
])
class
_PerAtomDictLoss
(
DictLoss
):
@
staticmethod
def
num_atoms
(
input
):
ret
=
[]
for
s
,
c
in
zip
(
input
[
'species'
],
input
[
'coordinates'
]):
ret
.
append
(
torch
.
full
((
c
.
shape
[
0
],),
len
(
s
),
dtype
=
c
.
dtype
,
device
=
c
.
device
))
return
torch
.
cat
(
ret
)
def
forward
(
self
,
input
,
other
):
loss
=
self
.
loss
(
input
[
self
.
key
],
other
[
self
.
key
])
loss
/=
self
.
num_atoms
(
input
)
n
=
loss
.
numel
()
return
loss
.
sum
()
/
n
class
DictMetric
(
Metric
):
class
DictMetric
(
Metric
):
def
__init__
(
self
,
key
,
metric
):
def
__init__
(
self
,
key
,
metric
):
...
@@ -41,7 +50,10 @@ class DictMetric(Metric):
...
@@ -41,7 +50,10 @@ class DictMetric(Metric):
return
self
.
metric
.
compute
()
return
self
.
metric
.
compute
()
def
MSELoss
(
key
):
def
MSELoss
(
key
,
per_atom
=
True
):
if
per_atom
:
return
_PerAtomDictLoss
(
key
,
torch
.
nn
.
MSELoss
(
reduce
=
False
))
else
:
return
DictLoss
(
key
,
torch
.
nn
.
MSELoss
())
return
DictLoss
(
key
,
torch
.
nn
.
MSELoss
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment