Commit 8a8df81d authored by Myle Ott, committed by Facebook GitHub Bot

Deprecate _aggregate_logging_outputs

The _aggregate_logging_outputs instance-method indirection is removed; callers now invoke the static aggregate_logging_outputs directly on the criterion class. To keep construction flexible enough for CompositeLoss (which needs per-instance state), criteria gain a build_criterion(args, task) classmethod that the criterion registry calls in place of the constructor.

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/498

Differential Revision: D14024524

Pulled By: myleott

fbshipit-source-id: 1b0be4bb212dbab41ea0959ac34020832ff00645
parent f296824f
fairseq/criterions/__init__.py
@@ -16,7 +16,7 @@ CRITERION_CLASS_NAMES = set()
 
 
 def build_criterion(args, task):
-    return CRITERION_REGISTRY[args.criterion](args, task)
+    return CRITERION_REGISTRY[args.criterion].build_criterion(args, task)
 
 
 def register_criterion(name):
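Note: the registry no longer calls the criterion constructor directly; it goes through a build_criterion classmethod, so individual criteria can customize their own construction. A minimal sketch of a criterion opting into the new hook (not part of this commit; the criterion name and extra constructor argument are hypothetical):

    from fairseq.criterions import register_criterion
    from fairseq.criterions.fairseq_criterion import FairseqCriterion

    @register_criterion('example_criterion')  # hypothetical name
    class ExampleCriterion(FairseqCriterion):

        def __init__(self, args, task, padding_idx):
            super().__init__(args, task)
            self.padding_idx = padding_idx  # extra arg, for illustration only

        @classmethod
        def build_criterion(cls, args, task):
            # Construction can now depend on the task, e.g. a dictionary
            # lookup (assumes the task exposes a target_dictionary).
            return cls(args, task, padding_idx=task.target_dictionary.pad())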
fairseq/criterions/composite_loss.py
@@ -24,19 +24,23 @@ class CompositeLoss(FairseqCriterion):
                             help='underlying criterion to use for the composite loss')
         # fmt: on
 
-    def __init__(self, args, task):
-        super().__init__(args, task)
+    @staticmethod
+    def build_underlying_criterion(args, task):
         saved_criterion = args.criterion
         args.criterion = args.underlying_criterion
         assert saved_criterion != args.underlying_criterion
-        self.underlying_criterion = task.build_criterion(args)
+        underlying_criterion = task.build_criterion(args)
         args.criterion = saved_criterion
+        return underlying_criterion
 
-    class FakeModel(nn.Module):
-
-        def __init__(self, model, net_out, target):
-            super(CompositeLoss.FakeModel, self).__init__()
-            self.model = model
-            self.net_out = net_out
-            self.target = target
+    @classmethod
+    def build_criterion(cls, args, task):
+        underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
+
+        class FakeModel(nn.Module):
+
+            def __init__(self, model, net_out, target):
+                super().__init__()
+                self.model = model
+                self.net_out = net_out
+                self.target = target
@@ -44,6 +48,9 @@ class CompositeLoss(FairseqCriterion):
 
-        def forward(self, **unused):
-            return self.net_out
-
-        def get_targets(self, *unused):
-            return self.target
+            def forward(self, **unused):
+                return self.net_out
+
+            def get_normalized_probs(self, net_output, log_probs, sample=None):
+                return self.model.get_normalized_probs(net_output, log_probs, sample=sample)
+
+            def get_targets(self, *unused):
+                return self.target
@@ -51,17 +58,24 @@ class CompositeLoss(FairseqCriterion):
-        def decoder(self):
-            return self.model.decoder
+            def decoder(self):
+                return self.model.decoder
 
-    def forward(self, model, sample, reduce=True):
-        net_outputs = model(**sample['net_input'])
-        targets = sample['target']
-
-        bsz = targets[0].size(0)
-        loss = net_outputs[0][0].new(1 if reduce else bsz).zero_()
-
-        sample_size = 0
-        logging_output = {}
-        for o, t in zip(net_outputs[0], targets):
-            m = CompositeLoss.FakeModel(model, (o, net_outputs[1]), t)
-            l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
-            loss += l
-            sample_size += ss
+        class _CompositeLoss(FairseqCriterion):
+
+            def __init__(self, args, task, underlying_criterion):
+                super().__init__(args, task)
+                self.underlying_criterion = underlying_criterion
+
+            def forward(self, model, sample, reduce=True):
+                net_outputs = model(**sample['net_input'])
+                targets = sample['target']
+
+                bsz = targets[0].size(0)
+                loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
+
+                sample_size = 0
+                logging_output = {}
+                for o, t in zip(net_outputs[0], targets):
+                    m = FakeModel(model, (o, net_outputs[1]), t)
+                    sample['target'] = t
+                    l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
+                    loss += l
+                    sample_size += ss
@@ -72,5 +86,8 @@ class CompositeLoss(FairseqCriterion):
-        logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
-        return loss, sample_size, logging_output
+                logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
+                return loss, sample_size, logging_output
 
-    def _aggregate_logging_outputs(self, logging_outputs):
-        return self.underlying_criterion._aggregate_logging_outputs(logging_outputs)
+            @staticmethod
+            def aggregate_logging_outputs(logging_outputs):
+                return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+        return _CompositeLoss(args, task, underlying_criterion)
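For reference, construction of the composite loss now flows through the classmethod factory and returns the closure-defined _CompositeLoss. A rough usage sketch (not from this commit; assumes args was parsed with --criterion composite_loss --underlying-criterion cross_entropy, and that task, model, and sample are already built):

    criterion = task.build_criterion(args)  # dispatches to CompositeLoss.build_criterion
    # `criterion` is a _CompositeLoss instance; its aggregate_logging_outputs
    # staticmethod closes over the underlying criterion and delegates to that
    # criterion's class.
    loss, sample_size, logging_output = criterion(model, sample)
    agg = criterion.__class__.aggregate_logging_outputs([logging_output])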
fairseq/criterions/fairseq_criterion.py
@@ -20,6 +20,10 @@ class FairseqCriterion(_Loss):
         """Add criterion-specific arguments to the parser."""
         pass
 
+    @classmethod
+    def build_criterion(cls, args, task):
+        return cls(args, task)
+
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -35,15 +39,6 @@ class FairseqCriterion(_Loss):
         """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError
 
-    def _aggregate_logging_outputs(self, logging_outputs):
-        """An instance method version of :func:`aggregate_logging_outputs`.
-
-        This can be overridden if needed, but please be careful not to rely
-        on shared state when aggregating logging outputs otherwise you may
-        get incorrect results.
-        """
-        return self.__class__.aggregate_logging_outputs(logging_outputs)
-
     @staticmethod
     def grad_denom(sample_sizes):
         """Compute the gradient denominator for a set of sample sizes."""
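The default implementation keeps every existing criterion working: unless a subclass overrides it, build_criterion simply forwards to the constructor. Illustration (cross_entropy chosen arbitrarily; args and task assumed already built):

    from fairseq.criterions.cross_entropy import CrossEntropyCriterion

    # These are equivalent under the default classmethod:
    c1 = CrossEntropyCriterion.build_criterion(args, task)
    c2 = CrossEntropyCriterion(args, task)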
fairseq/tasks/fairseq_task.py
@@ -245,7 +245,7 @@ class FairseqTask(object):
         return criterion.__class__.grad_denom(sample_sizes)
 
     def aggregate_logging_outputs(self, logging_outputs, criterion):
-        return criterion._aggregate_logging_outputs(logging_outputs)
+        return criterion.__class__.aggregate_logging_outputs(logging_outputs)
 
     def max_positions(self):
         """Return the max input length allowed by the task."""
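Aggregation is now always invoked on the criterion class rather than on an instance, so implementations cannot accidentally depend on per-instance state. Sketch of the data flow (the keys follow the usual cross-entropy logging outputs; the values are made up):

    # one logging_output dict per data-parallel worker
    logging_outputs = [
        {'loss': 4.2, 'ntokens': 1024, 'nsentences': 32, 'sample_size': 1024},
        {'loss': 3.9, 'ntokens': 980, 'nsentences': 32, 'sample_size': 980},
    ]
    agg = task.aggregate_logging_outputs(logging_outputs, criterion)
    # equivalent to: criterion.__class__.aggregate_logging_outputs(logging_outputs)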
fairseq/tasks/translation_moe.py
@@ -222,7 +222,7 @@ class TranslationMoETask(TranslationTask):
         )
 
     def aggregate_logging_outputs(self, logging_outputs, criterion):
-        agg_logging_outputs = criterion._aggregate_logging_outputs(logging_outputs)
+        agg_logging_outputs = criterion.__class__.aggregate_logging_outputs(logging_outputs)
         agg_logging_outputs['posterior'] = sum(
             log['posterior'] for log in logging_outputs if 'posterior' in log
        )
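Tasks that post-process the aggregated stats follow the same pattern as TranslationMoETask above: aggregate through the criterion class, then attach task-specific keys. A hypothetical variant:

    class ExampleTask(FairseqTask):  # hypothetical task, for illustration

        def aggregate_logging_outputs(self, logging_outputs, criterion):
            agg = criterion.__class__.aggregate_logging_outputs(logging_outputs)
            # attach a derived, task-specific statistic (hypothetical key)
            agg['num_logs'] = len(logging_outputs)
            return agg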