Commit 03a57dec authored by Myle Ott, committed by Facebook Github Bot

Improve memory efficiency of FP16 optimization (#404)

Summary:
Previously, when training with --fp16, we stored a separate FP32 copy of the model parameters for optimization, which consumed a lot of memory. An alternative is to do the conversion to FP32 on the fly inside the optimizer step, which allows the caching allocator to reuse/save that memory.

This reduces peak memory usage by ~20% with a negligible reduction in training speed (~2% slower) when training a big transformer on 8 GPUs on wmt en-de with --update-freq=16.

This does not affect convergence, i.e., models will train exactly as they did before.
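
For illustration only, a minimal sketch of the on-the-fly idea (not the fairseq implementation in the diff below, which additionally handles dynamic loss scaling, gradient clipping, and multiply_grads; the fp16_step helper and the plain SGD setup are made up for this example):

import torch

def fp16_step(optimizer):
    # Upcast FP16 params/grads to FP32 just for the update, then cast back,
    # instead of keeping a persistent FP32 master copy of the model.
    for group in optimizer.param_groups:
        for p in group['params']:
            p.data = p.data.float()
            if p.grad is not None:
                p.grad.data = p.grad.data.float()
    optimizer.step()  # the update itself runs in FP32
    for group in optimizer.param_groups:
        for p in group['params']:
            p.data = p.data.half()
            if p.grad is not None:
                p.grad.data = p.grad.data.half()

w = torch.nn.Parameter(torch.zeros(4).half())
opt = torch.optim.SGD([w], lr=0.1)
w.grad = torch.ones_like(w)  # stand-in for a real FP16 backward pass
fp16_step(opt)
assert w.dtype == torch.float16  # model stays FP16 between updates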
Pull Request resolved: https://github.com/pytorch/fairseq/pull/404

Differential Revision: D13394376

Pulled By: myleott

fbshipit-source-id: 2b9f808548df4782110513c9cfc9f7c6159bcbbf
parent 0f833526
@@ -10,7 +10,7 @@ import torch
 from fairseq import optim, utils


-class DynamicLossScaler:
+class DynamicLossScaler(object):

     def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000, tolerance=0.05):
         self.loss_scale = init_scale
@@ -45,12 +45,54 @@ class DynamicLossScaler:
         return False


+class ConvertToFP32(object):
+    """
+    A wrapper around a list of params that will convert them to FP32 on the
+    first iteration, after which this essentially behaves like a normal list.
+    """
+
+    def __init__(self, params):
+        def convert_to_fp32(p):
+            p.data = p.data.float()
+            if p.grad is not None:
+                p.grad.data = p.grad.data.float()
+            return p
+
+        assert isinstance(params, list)
+        self.params = params
+        self.itr = map(convert_to_fp32, params)
+
+    def __len__(self):
+        return len(self.params)
+
+    def __iter__(self):
+        if self.itr is not None:
+            return self
+        else:
+            return iter(self.params)
+
+    def __next__(self):
+        try:
+            return next(self.itr)
+        except StopIteration:
+            self.itr = None
+            raise StopIteration
+
+
 class FP16Optimizer(optim.FairseqOptimizer):
-    def __init__(self, args, params, fp32_optimizer, fp32_params):
+    """
+    Wrap an *optimizer* to support FP16 (mixed precision) training.
+
+    Args:
+        args (argparse.Namespace): fairseq args
+        params (iterable): iterable of parameters to optimize
+        optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
+    """
+
+    def __init__(self, args, params, optimizer):
         super().__init__(args, params)
-        self.fp32_optimizer = fp32_optimizer
-        self.fp32_params = fp32_params
+        self.wrapped_optimizer = optimizer

         if getattr(args, 'fp16_scale_window', None) is None:
             if len(args.update_freq) > 1:
@@ -70,37 +112,26 @@ class FP16Optimizer(optim.FairseqOptimizer):
     @staticmethod
     def build_optimizer(args, params):
-        # create FP32 copy of parameters and grads
-        total_param_size = sum(p.data.numel() for p in params)
-        fp32_params = params[0].new(0).float().new(total_param_size)
-        offset = 0
-        for p in params:
-            numel = p.data.numel()
-            fp32_params[offset:offset+numel].copy_(p.data.view(-1))
-            offset += numel
-        fp32_params = torch.nn.Parameter(fp32_params)
-        fp32_params.grad = fp32_params.data.new(total_param_size)
-        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
-        return FP16Optimizer(args, params, fp32_optimizer, fp32_params)
+        fp16_optimizer = optim.build_optimizer(args, params)
+        return FP16Optimizer(args, params, fp16_optimizer)

     @property
     def optimizer(self):
-        return self.fp32_optimizer.optimizer
+        return self.wrapped_optimizer.optimizer

     @property
     def optimizer_config(self):
-        return self.fp32_optimizer.optimizer_config
+        return self.wrapped_optimizer.optimizer_config

     def get_lr(self):
-        return self.fp32_optimizer.get_lr()
+        return self.wrapped_optimizer.get_lr()

     def set_lr(self, lr):
-        self.fp32_optimizer.set_lr(lr)
+        self.wrapped_optimizer.set_lr(lr)

     def state_dict(self):
         """Return the optimizer's state dict."""
-        state_dict = self.fp32_optimizer.state_dict()
+        state_dict = self.wrapped_optimizer.state_dict()
         state_dict['loss_scale'] = self.scaler.loss_scale
         return state_dict
@@ -114,41 +145,33 @@ class FP16Optimizer(optim.FairseqOptimizer):
         """
         if 'loss_scale' in state_dict:
             self.scaler.loss_scale = state_dict['loss_scale']
-        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
+        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

     def backward(self, loss):
         loss = loss * self.scaler.loss_scale
         loss.backward()
-        self._needs_sync = True
-
-    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
-        if self._needs_sync:
-            # copy FP16 grads to FP32
-            offset = 0
-            for p in self.params:
-                if not p.requires_grad:
-                    continue
-                grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
-                numel = grad_data.numel()
-                self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
-                offset += numel
-
-            # correct for dynamic loss scaler
-            self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
-
-            self._needs_sync = False
+        self._grads_are_scaled = True
+
+    def _unscale_grads(self, multiply_grads=1.):
+        if self._grads_are_scaled:
+            self._grads_are_scaled = False
+
+            # correct for dynamic loss scaler
+            self.wrapped_optimizer.multiply_grads(multiply_grads / self.scaler.loss_scale)
+        else:
+            assert multiply_grads == 1.

     def multiply_grads(self, c):
         """Multiplies grads by a constant ``c``."""
-        if self._needs_sync:
-            self._sync_fp16_grads_to_fp32(c)
+        if self._grads_are_scaled:
+            self._unscale_grads(c)
         else:
-            self.fp32_params.grad.data.mul_(c)
+            self.wrapped_optimizer.multiply_grads(c)

     def clip_grad_norm(self, max_norm):
         """Clips gradient norm and updates dynamic loss scaler."""
-        self._sync_fp16_grads_to_fp32()
-        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
+        self._unscale_grads()
+        grad_norm = self.wrapped_optimizer.clip_grad_norm(max_norm)

         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -163,27 +186,28 @@ class FP16Optimizer(optim.FairseqOptimizer):
                     'increasing the batch size.'
                 ).format(self.args.min_loss_scale))
             raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
         return grad_norm

     def step(self, closure=None):
         """Performs a single optimization step."""
-        self._sync_fp16_grads_to_fp32()
-        self.fp32_optimizer.step(closure)
-
-        # copy FP32 params back into FP16 model
-        offset = 0
-        for p in self.params:
-            if not p.requires_grad:
-                continue
-            numel = p.data.numel()
-            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
-            offset += numel
+        self._unscale_grads()
+
+        # convert params and grads to FP32 (lazily)
+        for group in self.wrapped_optimizer.optimizer.param_groups:
+            group['params'] = ConvertToFP32(group['params'])
+
+        self.wrapped_optimizer.step(closure)
+
+        # convert params back to FP16
+        for group in self.wrapped_optimizer.optimizer.param_groups:
+            group['params'] = group['params'].params  # unwrap from ConvertToFP32
+            for p in group['params']:
+                p.data = p.data.half()
+                if p.grad is not None:
+                    p.grad.data = p.grad.data.half()

     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        self.fp32_optimizer.zero_grad()
-        for p in self.params:
-            if p.grad is not None:
-                p.grad.detach_()
-                p.grad.zero_()
-        self._needs_sync = False
+        self.wrapped_optimizer.zero_grad()
+        self._grads_are_scaled = False
@@ -214,6 +214,7 @@ class Trainer(object):
             sample_sizes = list(chain.from_iterable(sample_sizes))
             ooms = sum(ooms)

+        self.meters['oom'].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):
             print('| WARNING: OOM in all workers, skipping update')
             self.zero_grad()
@@ -256,7 +257,6 @@ class Trainer(object):
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
-            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
@@ -378,14 +378,6 @@ def item(tensor):
     return tensor


-def clip_grad_norm_(tensor, max_norm):
-    grad_norm = item(torch.norm(tensor))
-    if grad_norm > max_norm > 0:
-        clip_coef = max_norm / (grad_norm + 1e-6)
-        tensor.mul_(clip_coef)
-    return grad_norm
-
-
 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float('-inf')).type_as(t)
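
As a usage note on the ConvertToFP32 wrapper added above: the trick works because a torch optimizer only walks group['params'] inside step(). The following self-contained sketch mirrors the same pattern; LazyFP32Params is an illustrative stand-in written for this example (not fairseq code), and it assumes a stock torch.optim.SGD whose step() iterates the params exactly once:

import torch

class LazyFP32Params(object):
    """Upcast each FP16 param (and its grad) to FP32 the first time it is iterated."""

    def __init__(self, params):
        self.params = params
        self.itr = iter(params)

    def __len__(self):
        return len(self.params)

    def __iter__(self):
        # While the lazy conversion pass is pending, iterate through __next__;
        # afterwards behave like the plain underlying list.
        return self if self.itr is not None else iter(self.params)

    def __next__(self):
        try:
            p = next(self.itr)
        except StopIteration:
            self.itr = None
            raise
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
        return p

w = torch.nn.Parameter(torch.zeros(4).half())
w.grad = torch.ones_like(w)
opt = torch.optim.SGD([w], lr=0.1)

for group in opt.param_groups:
    group['params'] = LazyFP32Params(group['params'])
opt.step()  # the wrapped optimizer sees FP32 tensors
for group in opt.param_groups:
    group['params'] = group['params'].params  # unwrap again
    for p in group['params']:
        p.data = p.data.half()
        if p.grad is not None:
            p.grad.data = p.grad.data.half()
assert w.dtype == torch.float16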