Unverified Commit 59b31d84 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

implement per atom loss (#51)

parent 84439caf
...@@ -4,14 +4,6 @@ from ignite.metrics import RootMeanSquaredError ...@@ -4,14 +4,6 @@ from ignite.metrics import RootMeanSquaredError
import torch import torch
def num_atoms(input):
ret = []
for s, c in zip(input['species'], input['coordinates']):
ret.append(torch.full((c.shape[0],), len(s),
dtype=c.dtype, device=c.device))
return torch.cat(ret)
class DictLoss(_Loss): class DictLoss(_Loss):
def __init__(self, key, loss): def __init__(self, key, loss):
...@@ -23,6 +15,23 @@ class DictLoss(_Loss): ...@@ -23,6 +15,23 @@ class DictLoss(_Loss):
return self.loss(input[self.key], other[self.key]) return self.loss(input[self.key], other[self.key])
class _PerAtomDictLoss(DictLoss):
@staticmethod
def num_atoms(input):
ret = []
for s, c in zip(input['species'], input['coordinates']):
ret.append(torch.full((c.shape[0],), len(s),
dtype=c.dtype, device=c.device))
return torch.cat(ret)
def forward(self, input, other):
loss = self.loss(input[self.key], other[self.key])
loss /= self.num_atoms(input)
n = loss.numel()
return loss.sum() / n
class DictMetric(Metric): class DictMetric(Metric):
def __init__(self, key, metric): def __init__(self, key, metric):
...@@ -41,8 +50,11 @@ class DictMetric(Metric): ...@@ -41,8 +50,11 @@ class DictMetric(Metric):
return self.metric.compute() return self.metric.compute()
def MSELoss(key): def MSELoss(key, per_atom=True):
return DictLoss(key, torch.nn.MSELoss()) if per_atom:
return _PerAtomDictLoss(key, torch.nn.MSELoss(reduce=False))
else:
return DictLoss(key, torch.nn.MSELoss())
def RMSEMetric(key): def RMSEMetric(key):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment