"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "83b0b49ec9110454f46dbce0a5695b7be96bc52d"
Commit 8bafae2e authored by Myle Ott

Better logging from criterions

parent e432459b
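Note (reviewer sketch, not part of the commit): the change reworks the criterion API so that forward() takes the model and runs it itself, returns a dict of logging values instead of a bare loss, the trainer aggregates one dict per GPU and attaches 'gnorm', and train.py turns every remaining key into its own meter. A minimal, self-contained sketch of that flow with made-up numbers (ToyCriterion is illustrative only; the real signature is forward(model, sample, grad_denom)):

import collections
import math


class ToyCriterion(object):
    """Stand-in for a criterion under the new dict-based interface."""

    def forward(self, loss_value, grad_denom):
        # the loss is pre-normalized by grad_denom so replicas can be summed directly
        return {'loss': loss_value / grad_denom}

    def aggregate(self, loss_dicts):
        # merge the per-GPU dicts; the / math.log(2) converts nats to bits (see the note below)
        return {'loss': sum(d['loss'] for d in loss_dicts if 'loss' in d) / math.log(2)}


crit = ToyCriterion()
per_gpu = [crit.forward(3.0, grad_denom=2.0), crit.forward(5.0, grad_denom=2.0)]
loss_dict = crit.aggregate(per_gpu)
loss_dict['gnorm'] = 0.45                    # the trainer attaches the clipped grad norm

loss = loss_dict.pop('loss')                 # train.py logs 'loss' with its own meter...
extra_meters = collections.defaultdict(list)
for k, v in loss_dict.items():               # ...and every remaining key generically
    extra_meters[k].append(v)
print(round(loss, 2), dict(extra_meters))    # 5.77 {'gnorm': [0.45]}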
@@ -21,11 +21,16 @@ class CrossEntropyCriterion(FairseqCriterion):
     def grad_denom(self, samples):
         return sum(s['ntokens'] if s else 0 for s in samples)
 
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
+        net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return loss
-
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+        return {
+            'loss': loss / grad_denom,
+        }
+
+    def aggregate(self, loss_dicts):
+        return {
+            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+        }
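Worth noting (reviewer comment, not in the diff): F.cross_entropy returns a negative log-likelihood in nats, so dividing the aggregated loss by math.log(2) reports it in bits per token, which is what lets train.py print perplexity as math.pow(2, loss). A quick numeric check:

import math

loss_nats = 2.0                          # summed NLL as produced by F.cross_entropy
loss_bits = loss_nats / math.log(2)      # ~2.885, the value the meters see
# base-2 perplexity of the bits value equals base-e perplexity of the nats value
assert abs(2 ** loss_bits - math.exp(loss_nats)) < 1e-9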
@@ -18,18 +18,10 @@ class FairseqCriterion(_Loss):
         """Gradient normalization term for DataParallel training."""
         raise NotImplementedError
 
-    def prepare(self, model, sample):
-        """Apply criterion-specific modifications to the sample."""
-        return sample
-
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
         """Compute the loss for the given sample and network output."""
         raise NotImplementedError
 
-    def aggregate(self, losses):
-        """Aggregate losses from DataParallel training.
-
-        Takes a list of losses as input (as returned by forward) and
-        aggregates them into the total loss for the mini-batch.
-        """
+    def aggregate(self, losses, log_infos):
+        """Aggregate losses from DataParallel training."""
        raise NotImplementedError
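The payoff of the widened interface is that a criterion can now surface more than one number. A hedged sketch (illustrative only; ToyLoggingCriterion is not part of this commit and skips the FairseqCriterion base class to stay self-contained) of a criterion that reports an extra 'nll_loss' key, which the updated train.py below will log under its own meter:

import math

import torch.nn.functional as F


class ToyLoggingCriterion(object):
    padding_idx = 0

    def grad_denom(self, samples):
        return sum(s['ntokens'] if s else 0 for s in samples)

    def forward(self, model, sample, grad_denom):
        net_output = model(**sample['net_input'])
        lprobs = net_output.view(-1, net_output.size(-1))
        target = sample['target'].view(-1)
        nll = F.cross_entropy(lprobs, target, size_average=False,
                              ignore_index=self.padding_idx)
        return {
            'loss': nll / grad_denom,                # Variable: used for backward()
            'nll_loss': nll.data[0] / grad_denom,    # plain float: logging only
        }

    def aggregate(self, loss_dicts):
        return {
            'loss': sum(d['loss'].data[0] for d in loss_dicts if 'loss' in d) / math.log(2),
            'nll_loss': sum(d.get('nll_loss', 0) for d in loss_dicts) / math.log(2),
        }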
@@ -52,11 +52,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
     def grad_denom(self, samples):
         return sum(s['ntokens'] if s else 0 for s in samples)
 
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
+        net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return loss
-
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+        return {
+            'loss': loss / grad_denom,
+        }
+
+    def aggregate(self, loss_dicts):
+        return {
+            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+        }
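One more reviewer observation: both criterions now divide inside forward() by grad_denom, which the trainer computes over all per-GPU sub-batches, so every replica normalizes by the same token count and the old post-hoc `loss_ /= grad_denom` step in the trainer goes away. A tiny illustration of how grad_denom is formed (numbers invented):

# samples holds one sub-batch per GPU; a GPU with no data contributes None
samples = [{'ntokens': 3000}, {'ntokens': 2800}, None]
grad_denom = sum(s['ntokens'] if s else 0 for s in samples)
print(grad_denom)   # 5800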
@@ -146,10 +146,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         ]
 
         # aggregate losses and gradient norms
-        losses, grad_norms = Future.gen_tuple_list(losses)
-        loss = self.criterion.aggregate(losses)
+        loss_dicts = Future.gen_list(losses)
+        loss_dict = self.criterion.aggregate(loss_dicts)
+        loss_dict['gnorm'] = loss_dicts[0]['gnorm']
 
-        return loss, grad_norms[0]
+        return loss_dict
 
     def _async_train_step(self, rank, device_id, grad_denom):
         self.model.train()
@@ -159,14 +160,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
-        loss = 0
+        loss_dict = {}
         if self._sample is not None:
-            self._sample = self.criterion.prepare(self.model, self._sample)
-            net_output = self.model(**self._sample['net_input'])
-            loss_ = self.criterion(net_output, self._sample)
-            if grad_denom is not None:
-                loss_ /= grad_denom
-            loss_.backward()
-            loss = loss_.data[0]
+            loss_dict = self.criterion(self.model, self._sample, grad_denom)
+            loss_dict['loss'].backward()
+            loss = loss_dict['loss'].data[0]
 
         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
@@ -176,12 +174,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         nccl.all_reduce(self.flat_grads)
 
         # clip grads
-        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
 
-        return loss, grad_norm
+        return loss_dict
 
     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
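Reviewer note on the 'gnorm' plumbing: _async_train_step clips after nccl.all_reduce, so every replica holds the same summed gradient and would compute the same norm, which is why train_step can simply take loss_dicts[0]['gnorm']. A toy, single-process illustration of that invariant (plain torch, no NCCL; the tensors are invented):

import torch

# pretend these are the flattened grads each GPU produced
per_gpu_grads = [torch.tensor([0.1, -0.2, 0.3]), torch.tensor([0.05, 0.4, -0.1])]
reduced = sum(per_gpu_grads)                           # what all_reduce leaves on *every* GPU
gnorm_per_replica = [float(reduced.norm()) for _ in per_gpu_grads]
print(gnorm_per_replica)                               # same number twice; replica 0 suffices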
@@ -218,20 +216,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         ]
 
         # aggregate losses
-        loss = self.criterion.aggregate(Future.gen_list(losses))
+        loss_dict = self.criterion.aggregate(Future.gen_list(losses))
 
-        return loss
+        return loss_dict
 
     def _async_valid_step(self, rank, device_id, grad_denom):
         if self._sample is None:
-            return 0
+            return {}
 
         self.model.eval()
-        self._sample = self.criterion.prepare(self.model, self._sample)
-        net_output = self.model(**self._sample['net_input'])
-        loss = self.criterion(net_output, self._sample)
-        if grad_denom is not None:
-            loss /= grad_denom
-        return loss.data[0]
+        return self.criterion(self.model, self._sample, grad_denom)
 
     def get_lr(self):
         """Get the current learning rate."""
@@ -115,13 +115,15 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     wpb_meter = AverageMeter()    # words per batch
     wps_meter = TimeMeter()       # words per second
     clip_meter = AverageMeter()   # % of updates clipped
-    gnorm_meter = AverageMeter()  # gradient norm
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     desc = '| epoch {:03d}'.format(epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
-            loss, grad_norm = trainer.train_step(sample)
+            loss_dict = trainer.train_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
             ntokens = sum(s['ntokens'] for s in sample)
             src_size = sum(s['src_tokens'].size(0) for s in sample)
@@ -129,8 +131,12 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             bsz_meter.update(src_size)
             wpb_meter.update(ntokens)
             wps_meter.update(ntokens)
-            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
-            gnorm_meter.update(grad_norm)
+            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
 
             t.set_postfix(collections.OrderedDict([
                 ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
@@ -139,8 +145,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                 ('lr', lr),
                 ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
-            ]), refresh=False)
+            ] + extra_postfix), refresh=False)
 
             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
@@ -148,16 +153,17 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                 trainer.save_checkpoint(args, epoch, i + 1)
 
-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
-        t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
-                           round(wps_meter.elapsed_time),
-                           round(wps_meter.avg),
-                           round(wpb_meter.avg),
-                           round(bsz_meter.avg),
-                           lr, clip_meter.avg * 100,
-                           gnorm_meter.avg))
+        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
+            loss_meter.avg, math.pow(2, loss_meter.avg))
+        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
+            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
+        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
+            round(bsz_meter.avg), lr, clip_meter.avg * 100)
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
 
 def validate(args, epoch, trainer, dataset, subset, ngpus):
@@ -168,18 +174,35 @@ def validate(args, epoch, trainer, dataset, subset, ngpus):
         max_positions=args.max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
     with progress_bar(itr, desc, leave=False) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
+            loss_dict = trainer.valid_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
             ntokens = sum(s['ntokens'] for s in sample)
-            loss = trainer.valid_step(sample)
             loss_meter.update(loss, ntokens)
-            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+
+            t.set_postfix(collections.OrderedDict([
+                ('loss', '{:.2f}'.format(loss_meter.avg)),
+            ] + extra_postfix), refresh=False)
 
         val_loss = loss_meter.avg
-        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
-                .format(val_loss, math.pow(2, val_loss)))
+        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
+            val_loss, math.pow(2, val_loss))
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
     # update and return the learning rate
     return val_loss
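With these changes the per-epoch train and valid summary lines grow one ' | key value' column per extra meter instead of the hard-coded gnorm field. A hedged example of what the appended suffix looks like (values invented):

extra_meters_avg = {'gnorm': 0.4532}     # whatever keys the criterion/trainer reported
suffix = ''.join(' | {} {:.4f}'.format(k, v) for k, v in extra_meters_avg.items())
print(repr(suffix))                      # ' | gnorm 0.4532'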