"src/diffusers/commands/diffusers_cli.py" did not exist on "27d11a0094e292a8d790714d1b5cdf5e9186814d"
Commit 8a8df81d authored by Myle Ott, committed by Facebook GitHub Bot
Browse files

Deprecate _aggregate_logging_outputs

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/498

Differential Revision: D14024524

Pulled By: myleott

fbshipit-source-id: 1b0be4bb212dbab41ea0959ac34020832ff00645
parent f296824f
......@@ -16,7 +16,7 @@ CRITERION_CLASS_NAMES = set()
# NOTE(review): scraped diff artifact — the two return statements below are the
# pre-commit (old) and post-commit (new) bodies of the same function; only the
# second exists in the committed file.
def build_criterion(args, task):
# Old: instantiate the registered criterion class directly.
return CRITERION_REGISTRY[args.criterion](args, task)
# New: go through the class's build_criterion factory method added by this commit.
return CRITERION_REGISTRY[args.criterion].build_criterion(args, task)
def register_criterion(name):
......
......@@ -24,53 +24,70 @@ class CompositeLoss(FairseqCriterion):
help='underlying criterion to use for the composite loss')
# fmt: on
# NOTE(review): this span is a two-sided diff of fairseq's CompositeLoss rendered
# as flat text — old (pre-commit) and new (post-commit) lines are interleaved and
# indentation was lost, so it is not valid Python as-is. Comments tag what each
# run of lines appears to be; verify against the upstream commit before relying
# on them.
# Old side: the constructor built the underlying criterion eagerly onto self.
def __init__(self, args, task):
super().__init__(args, task)
# New side: a static helper that temporarily swaps args.criterion for
# args.underlying_criterion, builds that criterion via the task, then restores
# args.criterion before returning.
@staticmethod
def build_underlying_criterion(args, task):
saved_criterion = args.criterion
args.criterion = args.underlying_criterion
assert saved_criterion != args.underlying_criterion
# Old line (stored on self) immediately followed by its new replacement
# (plain local variable):
self.underlying_criterion = task.build_criterion(args)
underlying_criterion = task.build_criterion(args)
args.criterion = saved_criterion
return underlying_criterion
# New side: factory classmethod added by this commit; builds the underlying
# criterion and returns a _CompositeLoss wrapping it (see end of span).
@classmethod
def build_criterion(cls, args, task):
underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
# FakeModel wraps the real model so the underlying criterion sees a single
# (output, extra) pair and a single target. Two copies appear below, one per
# diff side.
class FakeModel(nn.Module):
def __init__(self, model, net_out, target):
super().__init__()
self.model = model
self.net_out = net_out
self.target = target
def forward(self, **unused):
return self.net_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
return self.model.get_normalized_probs(net_output, log_probs, sample=sample)
def get_targets(self, *unused):
return self.target
# Second FakeModel copy (other diff side); note the explicit
# super(CompositeLoss.FakeModel, self) call — presumably the class-level
# version — and the added decoder passthrough property. TODO confirm sides
# against the upstream commit.
class FakeModel(nn.Module):
def __init__(self, model, net_out, target):
super(CompositeLoss.FakeModel, self).__init__()
self.model = model
self.net_out = net_out
self.target = target
@property
def decoder(self):
return self.model.decoder
def forward(self, **unused):
return self.net_out
# New side: the criterion actually returned by build_criterion; it receives the
# pre-built underlying criterion. The diff interleaves its first lines with
# leftovers of the other FakeModel's methods.
class _CompositeLoss(FairseqCriterion):
def get_targets(self, *unused):
return self.target
def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
self.underlying_criterion = underlying_criterion
@property
def decoder(self):
return self.model.decoder
# forward() appears twice (old/new). Both run the model once, then apply the
# underlying criterion to each (output, target) pair and average the results.
def forward(self, model, sample, reduce=True):
net_outputs = model(**sample['net_input'])
targets = sample['target']
def forward(self, model, sample, reduce=True):
net_outputs = model(**sample['net_input'])
targets = sample['target']
bsz = targets[0].size(0)
# One diff side adds .float() before zeroing the loss accumulator:
loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
bsz = targets[0].size(0)
loss = net_outputs[0][0].new(1 if reduce else bsz).zero_()
sample_size = 0
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = FakeModel(model, (o, net_outputs[1]), t)
# Old loop also mutated sample['target'] in place per iteration:
sample['target'] = t
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
# Other side of the loop: same accumulation, FakeModel referenced via
# CompositeLoss, and sample left untouched.
sample_size = 0
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = CompositeLoss.FakeModel(model, (o, net_outputs[1]), t)
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
# Average loss and sample_size over the number of targets (both diff sides).
loss.div_(len(targets))
sample_size /= len(targets)
loss.div_(len(targets))
sample_size /= len(targets)
# Record the (possibly reduced) loss for logging; utils.item unwraps a
# 1-element tensor to a Python scalar.
logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
# New side: static aggregation closing over build_criterion's
# underlying_criterion and delegating to its class-level aggregator.
@staticmethod
def aggregate_logging_outputs(logging_outputs):
return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)
# Old side: the deprecated instance-method delegation removed by this commit.
def _aggregate_logging_outputs(self, logging_outputs):
return self.underlying_criterion._aggregate_logging_outputs(logging_outputs)
return _CompositeLoss(args, task, underlying_criterion)
......@@ -20,6 +20,10 @@ class FairseqCriterion(_Loss):
"""Add criterion-specific arguments to the parser."""
pass
@classmethod
def build_criterion(cls, args, task):
    """Factory hook for constructing this criterion.

    The default behavior simply forwards ``args`` and ``task`` to the
    constructor; subclasses may override to customize construction.
    """
    criterion = cls(args, task)
    return criterion
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......@@ -35,15 +39,6 @@ class FairseqCriterion(_Loss):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError
def _aggregate_logging_outputs(self, logging_outputs):
"""An instance method version of :func:`aggregate_logging_outputs`.
This can be overridden if needed, but please be careful not to rely
on shared state when aggregating logging outputs otherwise you may
get incorrect results.
"""
return self.__class__.aggregate_logging_outputs(logging_outputs)
@staticmethod
def grad_denom(sample_sizes):
"""Compute the gradient denominator for a set of sample sizes."""
......
......@@ -245,7 +245,7 @@ class FairseqTask(object):
return criterion.__class__.grad_denom(sample_sizes)
# NOTE(review): diff artifact — the two return lines below are the old and new
# bodies of this FairseqTask method; only one exists in the committed file.
def aggregate_logging_outputs(self, logging_outputs, criterion):
# Old: went through the instance method deprecated by this commit.
return criterion._aggregate_logging_outputs(logging_outputs)
# New: call the criterion class's aggregate_logging_outputs directly.
return criterion.__class__.aggregate_logging_outputs(logging_outputs)
def max_positions(self):
"""Return the max input length allowed by the task."""
......
......@@ -222,7 +222,7 @@ class TranslationMoETask(TranslationTask):
)
# NOTE(review): hunk from TranslationMoETask; the two agg_logging_outputs
# assignments are the old/new diff sides, and the method body continues past
# the visible hunk, so this span is incomplete.
def aggregate_logging_outputs(self, logging_outputs, criterion):
# Old: deprecated instance-method path.
agg_logging_outputs = criterion._aggregate_logging_outputs(logging_outputs)
# New: static call on the criterion's class.
agg_logging_outputs = criterion.__class__.aggregate_logging_outputs(logging_outputs)
# Sum the 'posterior' entries across logging outputs that carry one.
agg_logging_outputs['posterior'] = sum(
log['posterior'] for log in logging_outputs if 'posterior' in log
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment