"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "abe36e2e5fb9104eca2456945e12b49f93fce475"
Unverified Commit e2bc9f29 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add DictMetric (#31)

parent 10699bf7
...@@ -4,7 +4,9 @@ if sys.version_info.major >= 3: ...@@ -4,7 +4,9 @@ if sys.version_info.major >= 3:
import os import os
import unittest import unittest
import torch import torch
from ignite.engine import create_supervised_trainer from ignite.metrics import RootMeanSquaredError
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator
import torchani import torchani
import torchani.data import torchani.data
...@@ -37,11 +39,17 @@ if sys.version_info.major >= 3: ...@@ -37,11 +39,17 @@ if sys.version_info.major >= 3:
nnp = Flatten(nnp) nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp) batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
loss = torchani.ignite.DictLosses({'energies': torch.nn.MSELoss()}) loss = torchani.ignite.DictLoss('energies', torch.nn.MSELoss())
optimizer = torch.optim.SGD(container.parameters(), optimizer = torch.optim.SGD(container.parameters(),
lr=0.001, momentum=0.8) lr=0.001, momentum=0.8)
trainer = create_supervised_trainer(container, optimizer, loss) trainer = create_supervised_trainer(container, optimizer, loss)
trainer.run(loader, max_epochs=10) trainer.run(loader, max_epochs=10)
metric = torchani.ignite.DictMetric('energies',
RootMeanSquaredError())
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': metric
})
evaluator.run(loader)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
from .container import Container from .container import Container
from .dict_loss import DictLosses from .loss_metrics import DictLoss, DictMetric
__all__ = ['Container', 'DictLosses'] __all__ = ['Container', 'DictLoss', 'DictMetric']
from torch.nn.modules.loss import _Loss
class DictLosses(_Loss):
def __init__(self, losses):
super(DictLosses, self).__init__()
self.losses = losses
def forward(self, input, other):
total = 0
for i in self.losses:
total += self.losses[i](input[i], other[i])
return total
from torch.nn.modules.loss import _Loss
from ignite.metrics.metric import Metric
class DictLoss(_Loss):
def __init__(self, key, loss):
super(DictLoss, self).__init__()
self.key = key
self.loss = loss
def forward(self, input, other):
return self.loss(input[self.key], other[self.key])
class DictMetric(Metric):
def __init__(self, key, metric):
self.key = key
self.metric = metric
super(DictMetric, self).__init__()
def reset(self):
self.metric.reset()
def update(self, output):
y_pred, y = output
self.metric.update((y_pred[self.key], y[self.key]))
def compute(self):
self.metric.compute()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment