Unverified Commit ea718be0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add common loss and metrics (#32)

parent e2bc9f29
......@@ -4,7 +4,6 @@ if sys.version_info.major >= 3:
import os
import unittest
import torch
from ignite.metrics import RootMeanSquaredError
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator
import torchani
......@@ -39,15 +38,13 @@ if sys.version_info.major >= 3:
nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
loss = torchani.ignite.DictLoss('energies', torch.nn.MSELoss())
optimizer = torch.optim.SGD(container.parameters(),
lr=0.001, momentum=0.8)
trainer = create_supervised_trainer(container, optimizer, loss)
trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
trainer.run(loader, max_epochs=10)
metric = torchani.ignite.DictMetric('energies',
RootMeanSquaredError())
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': metric
'RMSE': torchani.ignite.energy_rmse_metric
})
evaluator.run(loader)
......
from .container import Container
from .loss_metrics import DictLoss, DictMetric
from .loss_metrics import DictLoss, DictMetric, energy_mse_loss, \
energy_rmse_metric
__all__ = ['Container', 'DictLoss', 'DictMetric']
__all__ = ['Container', 'DictLoss', 'DictMetric', 'energy_mse_loss',
'energy_rmse_metric']
from torch.nn.modules.loss import _Loss
from ignite.metrics.metric import Metric
from ignite.metrics import RootMeanSquaredError
import torch
class DictLoss(_Loss):
......@@ -29,3 +31,7 @@ class DictMetric(Metric):
def compute(self):
self.metric.compute()
energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
energy_rmse_metric = DictMetric('energies', RootMeanSquaredError())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment