Unverified Commit b63f6e40 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add TransformedLoss to allow adding non-linearity to computed loss (#61)

parent b16a3234
...@@ -34,8 +34,11 @@ if sys.version_info.major >= 3: ...@@ -34,8 +34,11 @@ if sys.version_info.major >= 3:
model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten()) model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
container = torchani.ignite.Container({'energies': model}) container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters()) optimizer = torch.optim.Adam(container.parameters())
loss = torchani.ignite.TransformedLoss(
torchani.ignite.MSELoss('energies'),
lambda x: torch.exp(x) - 1)
trainer = create_supervised_trainer( trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies')) container, optimizer, loss)
evaluator = create_supervised_evaluator(container, metrics={ evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies') 'RMSE': torchani.ignite.RMSEMetric('energies')
}) })
......
from .container import Container from .container import Container
from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
TransformedLoss
__all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric'] __all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric',
'TransformedLoss']
...@@ -57,5 +57,16 @@ def MSELoss(key, per_atom=True): ...@@ -57,5 +57,16 @@ def MSELoss(key, per_atom=True):
return DictLoss(key, torch.nn.MSELoss()) return DictLoss(key, torch.nn.MSELoss())
class TransformedLoss(_Loss):
def __init__(self, origin, transform):
super(TransformedLoss, self).__init__()
self.origin = origin
self.transform = transform
def forward(self, input, other):
return self.transform(self.origin(input, other))
def RMSEMetric(key): def RMSEMetric(key):
return DictMetric(key, RootMeanSquaredError()) return DictMetric(key, RootMeanSquaredError())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment