Unverified Commit ea718be0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add common loss and metrics (#32)

parent e2bc9f29
...@@ -4,7 +4,6 @@ if sys.version_info.major >= 3: ...@@ -4,7 +4,6 @@ if sys.version_info.major >= 3:
import os import os
import unittest import unittest
import torch import torch
from ignite.metrics import RootMeanSquaredError
from ignite.engine import create_supervised_trainer, \ from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator create_supervised_evaluator
import torchani import torchani
...@@ -39,15 +38,13 @@ if sys.version_info.major >= 3: ...@@ -39,15 +38,13 @@ if sys.version_info.major >= 3:
nnp = Flatten(nnp) nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp) batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
loss = torchani.ignite.DictLoss('energies', torch.nn.MSELoss())
optimizer = torch.optim.SGD(container.parameters(), optimizer = torch.optim.SGD(container.parameters(),
lr=0.001, momentum=0.8) lr=0.001, momentum=0.8)
trainer = create_supervised_trainer(container, optimizer, loss) trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
trainer.run(loader, max_epochs=10) trainer.run(loader, max_epochs=10)
metric = torchani.ignite.DictMetric('energies',
RootMeanSquaredError())
evaluator = create_supervised_evaluator(container, metrics={ evaluator = create_supervised_evaluator(container, metrics={
'RMSE': metric 'RMSE': torchani.ignite.energy_rmse_metric
}) })
evaluator.run(loader) evaluator.run(loader)
......
from .container import Container from .container import Container
from .loss_metrics import DictLoss, DictMetric from .loss_metrics import DictLoss, DictMetric, energy_mse_loss, \
energy_rmse_metric
__all__ = ['Container', 'DictLoss', 'DictMetric'] __all__ = ['Container', 'DictLoss', 'DictMetric', 'energy_mse_loss',
'energy_rmse_metric']
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from ignite.metrics.metric import Metric from ignite.metrics.metric import Metric
from ignite.metrics import RootMeanSquaredError
import torch
class DictLoss(_Loss): class DictLoss(_Loss):
...@@ -29,3 +31,7 @@ class DictMetric(Metric): ...@@ -29,3 +31,7 @@ class DictMetric(Metric):
def compute(self): def compute(self):
self.metric.compute() self.metric.compute()
energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
energy_rmse_metric = DictMetric('energies', RootMeanSquaredError())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment