"vscode:/vscode.git/clone" did not exist on "6ea83608adcf9302d3d24733dc319ab4ea9607ad"
Commit 3f970086 authored by Myle Ott's avatar Myle Ott
Browse files

More flexible gradient normalization

parent 88a8bd42
......@@ -18,19 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
super().__init__()
self.padding_idx = padding_idx
def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, model, sample, grad_denom):
def forward(self, model, sample):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}
......@@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
def __init__(self):
super().__init__()
def grad_denom(self, samples):
"""Gradient normalization term for DataParallel training."""
raise NotImplementedError
def forward(self, model, sample):
"""Compute the loss for the given sample.
def forward(self, model, sample, grad_denom):
"""Compute the loss for the given sample and network output."""
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
def aggregate(self, losses, log_infos):
"""Aggregate losses from DataParallel training."""
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError
@staticmethod
def grad_denom(sample_sizes):
"""Compute the gradient denominator for a set of sample sizes."""
return sum(sample_sizes)
......@@ -49,19 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self.padding_idx = padding_idx
self.weights = weights
def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)
def forward(self, model, sample, grad_denom):
def forward(self, model, sample):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}
......@@ -15,7 +15,6 @@ import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq.criterions import FairseqCriterion
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
......@@ -74,6 +73,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
momentum=self.args.momentum,
weight_decay=self.args.weight_decay)
self.flat_grads = None
self.loss = None
# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
......@@ -136,35 +136,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# scatter sample across GPUs
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)
# forward pass
sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])
# forward pass, backward pass and gradient step
losses = [
self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms = Future.gen_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
]
])
# aggregate losses and gradient norms
loss_dicts = Future.gen_list(losses)
loss_dict = self.criterion.aggregate(loss_dicts)
loss_dict['gnorm'] = loss_dicts[0]['gnorm']
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
return loss_dict
return logging_output
def _async_train_step(self, rank, device_id, grad_denom):
def _async_forward(self, rank, device_id, eval=False):
if eval:
self.model.eval()
else:
self.model.train()
# zero grads even if self._sample is None, since we will all-reduce them
self.optimizer.zero_grad()
# calculate loss and grads
loss = 0
loss_dict = {}
if self._sample is not None:
loss_dict = self.criterion(self.model, self._sample, grad_denom)
loss_dict['loss'].backward()
loss = loss_dict['loss'].data[0]
if self._sample is None:
return 0, {}
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
return sample_size, logging_output
def _async_backward_and_opt(self, rank, device_id, grad_denom):
if self.loss is not None:
# backward pass
self.loss.backward()
# flatten grads into a contiguous block of memory
if self.flat_grads is None:
......@@ -173,13 +182,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# all-reduce grads
nccl.all_reduce(self.flat_grads)
# normalize grads
if grad_denom != 0:
self.flat_grads.div_(grad_denom)
# clip grads
loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
# take an optimization step
self.optimizer.step()
return loss_dict
# reset loss
self.loss = None
return grad_norm
def _flatten_grads_(self, model):
num_params = sum(p.data.numel() for p in model.parameters())
......@@ -206,25 +222,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# scatter sample across GPUs
self._scatter_samples(samples, volatile=True)
# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)
# forward pass
losses = [
self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
_sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
]
# aggregate losses
loss_dict = self.criterion.aggregate(Future.gen_list(losses))
])
return loss_dict
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
def _async_valid_step(self, rank, device_id, grad_denom):
if self._sample is None:
return {}
self.model.eval()
return self.criterion(self.model, self._sample, grad_denom)
return logging_output
def get_lr(self):
"""Get the current learning rate."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment