Commit 38e82904 authored by Myle Ott, committed by Facebook GitHub Bot

Update --memory-efficient-fp16 to work with c10d DDP

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/617

Differential Revision: D15555328

Pulled By: myleott

fbshipit-source-id: 35d1f329f887cb0b867c7a22f17a16f3c9c66815
parent 75cc8821
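For context, the change below follows one pattern in every supported optimizer (Adafactor, Adam, FusedAdam, NAG): parameters and gradients stay in FP16, but step() casts them to FP32 on the fly, performs the update in FP32, and copies the result back into the FP16 storage. A minimal sketch of that pattern, with purely illustrative names (this is not the fairseq code itself):

    import torch

    def memory_efficient_fp16_step(p, lr=0.1):
        # toy SGD-style update that computes in FP32 on FP16 storage
        if p.grad is None:
            return
        grad = p.grad.data.float()      # FP16 grad -> temporary FP32 buffer
        p_data_fp32 = p.data.float()    # FP16 param -> temporary FP32 buffer
        p_data_fp32.add_(-lr * grad)    # the update itself runs entirely in FP32
        p.data.copy_(p_data_fp32)       # cast back into the FP16 tensor in place

    p = torch.nn.Parameter(torch.randn(4).half())
    p.grad = torch.randn(4).half()
    memory_efficient_fp16_step(p)

Because the FP32 buffers are temporaries, no persistent FP32 master copy of the model is kept, which is what makes --memory-efficient-fp16 cheaper in memory than the FP32-master-copy path in FP16Optimizer.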
@@ -100,6 +100,10 @@ class Adafactor(torch.optim.Optimizer):
                         relative_step=relative_step, warmup_init=warmup_init)
         super(Adafactor, self).__init__(params, defaults)

+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
     def _get_lr(self, param_group, param_state):
         rel_step_sz = param_group['lr']
         if param_group['relative_step']:
@@ -138,7 +142,7 @@ class Adafactor(torch.optim.Optimizer):
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')
@@ -160,9 +164,18 @@ class Adafactor(torch.optim.Optimizer):
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                    state['RMS'] = 0
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(grad)
+                    if factored:
+                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].type_as(grad)
+                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].type_as(grad)
+                    else:
+                        state['exp_avg_sq'] = state['exp_avg_sq'].type_as(grad)
+
+                p_data_fp32 = p.data.float()

                state['step'] += 1
-                state['RMS'] = self._rms(p.data)
+                state['RMS'] = self._rms(p_data_fp32)
                group['lr'] = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@@ -192,8 +205,10 @@ class Adafactor(torch.optim.Optimizer):
                    update = exp_avg

                if group['weight_decay'] != 0:
-                    p.data.add_(-group['weight_decay'] * group['lr'], p.data)
+                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

-                p.data.add_(-update)
+                p_data_fp32.add_(-update)
+                p.data.copy_(p_data_fp32)

        return loss
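The new else: branch matters when training resumes from a checkpoint: after load_state_dict the exp_avg/exp_avg_sq buffers may come back in a different precision than the FP32 work tensors, so they are re-cast with type_as before any in-place math. A small illustration of the idea (values and shapes are arbitrary):

    import torch

    grad = torch.randn(4).half().float()                       # FP32 work buffer, as in the diff
    state = {'exp_avg_sq': torch.zeros(4).half()}               # state restored in FP16
    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(grad)     # re-cast to FP32 to match grad
    state['exp_avg_sq'].mul_(0.999).add_(grad * grad * 0.001)   # in-place update is now dtype-safe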
@@ -6,6 +6,8 @@
# can be found in the PATENTS file in the same directory.

import math
+import types

import torch
import torch.optim
@@ -19,7 +21,7 @@ class FairseqAdam(FairseqOptimizer):
        super().__init__(args, params)
        if torch.cuda.is_available():
            try:
-                from apex.optimizers import FusedAdam
+                from apex.optimizers import FusedAdam as _FusedAdam
                self._optimizer = FusedAdam(params, **self.optimizer_config)
            except ImportError:
                self._optimizer = Adam(params, **self.optimizer_config)
@@ -87,6 +89,10 @@ class Adam(torch.optim.Optimizer):
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -102,23 +108,30 @@ class Adam(torch.optim.Optimizer):
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

+                p_data_fp32 = p.data.float()
+
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    if amsgrad:
+                        state['max_exp_avg_sq'] = state['max_exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
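For reference, the FP32 moment buffers maintained here feed the standard Adam rule from Kingma & Ba, with the bias corrections folded into the step size exactly as in the step_size line of the next hunk:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \theta_t = \theta_{t-1} - \frac{\alpha\,\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}

Here m_t is exp_avg, v_t is exp_avg_sq, and g_t is the FP32-cast gradient; the only change in this commit is that θ is updated on the FP32 copy p_data_fp32 and then written back to the FP16 parameter.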
@@ -143,8 +156,157 @@ class Adam(torch.optim.Optimizer):
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
-                    p.data.add_(-group['weight_decay'] * group['lr'], p.data)
+                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+
+                p.data.copy_(p_data_fp32)

        return loss
+
+
+class FusedAdam(torch.optim.Optimizer):
+    """
+    Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
+    ``python setup.py install --cuda_ext --cpp_ext``.
+
+    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+    Compared to the original version in Apex, the fairseq version casts grads
+    and params to FP32 internally to support ``--memory-efficient-fp16``.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square. (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False) NOT SUPPORTED in FusedAdam!
+        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
+            adds eps to the bias-corrected second moment estimate before
+            evaluating square root instead of adding it to the square root of
+            second moment estimate as in the original paper. (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params,
+                 lr=1e-3, bias_correction=True,
+                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
+                 weight_decay=0., max_grad_norm=0., amsgrad=False):
+        global fused_adam_cuda
+        import importlib
+        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
+
+        if amsgrad:
+            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
+        defaults = dict(lr=lr, bias_correction=bias_correction,
+                        betas=betas, eps=eps, weight_decay=weight_decay,
+                        max_grad_norm=max_grad_norm)
+        super(FusedAdam, self).__init__(params, defaults)
+        self.eps_mode = 0 if eps_inside_sqrt else 1
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    def step(self, closure=None, grads=None, scale=1., grad_norms=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            grads (list of tensors, optional): weight gradient to use for the
+                optimizer update. If gradients have type torch.half, parameters
+                are expected to be in type torch.float. (default: None)
+            output params (list of tensors, optional): A reduced precision copy
+                of the updated weights written out in addition to the regular
+                updated weights. Have to be of same type as gradients. (default: None)
+            scale (float, optional): factor to divide gradient tensor values
+                by before applying to weights. (default: 1)
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        if grads is None:
+            grads_group = [None] * len(self.param_groups)
+        # backward compatibility
+        # assuming a list/generator of parameter means single group
+        elif isinstance(grads, types.GeneratorType):
+            grads_group = [grads]
+        elif type(grads[0]) != list:
+            grads_group = [grads]
+        else:
+            grads_group = grads
+
+        if grad_norms is None:
+            grad_norms = [None] * len(self.param_groups)
+
+        for group, grads_this_group, grad_norm in zip(self.param_groups, grads_group, grad_norms):
+            if grads_this_group is None:
+                grads_this_group = [None] * len(group['params'])
+
+            # compute combined scale factor for this group
+            combined_scale = scale
+            if group['max_grad_norm'] > 0:
+                # norm is in fact norm*scale
+                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
+                if clip > 1:
+                    combined_scale = clip * scale
+
+            bias_correction = 1 if group['bias_correction'] else 0
+
+            for p, grad in zip(group['params'], grads_this_group):
+                # note: p.grad should not ever be set for correct operation of
+                # a mixed precision optimizer that sometimes sends None gradients
+                if p.grad is None and grad is None:
+                    continue
+                if grad is None:
+                    grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg = state['exp_avg']
+                exp_avg_sq = state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                out_p = p.data
+                fused_adam_cuda.adam(p_data_fp32,
+                                     out_p,
+                                     exp_avg,
+                                     exp_avg_sq,
+                                     grad,
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     combined_scale,
+                                     state['step'],
+                                     self.eps_mode,
+                                     bias_correction,
+                                     group['weight_decay'])
+
+        return loss
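The combined_scale passed to the fused kernel folds loss-scale unscaling and gradient clipping into one divisor: grad_norm is computed on the still-scaled gradients, so grad_norm / scale is the true norm, and when that exceeds max_grad_norm the divisor is enlarged proportionally. A small worked example with made-up numbers:

    scale = 128.0            # loss scale applied during backward
    grad_norm = 640.0        # norm of the scaled grads, so the true norm is 5.0
    max_grad_norm = 1.0

    clip = ((grad_norm / scale) + 1e-6) / max_grad_norm   # ~5.0 -> clipping is needed
    combined_scale = clip * scale if clip > 1 else scale  # 640.0
    # dividing the grads by combined_scale both removes the loss scale and
    # clips the true gradient norm to max_grad_norm (5.0 / 5.0 = 1.0)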
@@ -96,3 +96,9 @@ class FairseqOptimizer(object):
            for p in group['params']:
                p.grad = None
        self.optimizer.zero_grad()
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        if hasattr(self.optimizer, 'supports_memory_efficient_fp16'):
+            return self.optimizer.supports_memory_efficient_fp16
+        return False
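This delegating property is what lets calling code refuse --memory-efficient-fp16 for optimizers that cannot do the FP32 conversion themselves. A hedged sketch of how a caller might branch on it (wrap_for_fp16 and the import path are illustrative, not the actual fairseq trainer code; the constructor signature and error message mirror the MemoryEfficientFP16Optimizer changes below):

    from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer  # path as upstream

    def wrap_for_fp16(args, params, optimizer):
        # 'optimizer' is a FairseqOptimizer; the property delegates to the torch optimizer inside
        if getattr(args, 'memory_efficient_fp16', False):
            if not optimizer.supports_memory_efficient_fp16:
                raise ValueError(
                    'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
                )
            return MemoryEfficientFP16Optimizer(args, params, optimizer)
        # otherwise fall back to FP16Optimizer (FP32 master copy); construction omitted here
        ...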
@@ -210,65 +210,28 @@ class FP16Optimizer(optim.FairseqOptimizer):
        self._needs_sync = False

-class ConvertToFP32(object):
-    """
-    A wrapper around a list of params that will convert them to FP32 on the
-    first iteration, after which this essentially behaves like a normal list.
-    """
-
-    def __init__(self, params):
-        def convert_to_fp32(p):
-            p.data = p.data.float()
-            if p.grad is not None:
-                p.grad.data = p.grad.data.float()
-            return p
-
-        assert isinstance(params, list)
-        self.params = params
-        self.itr = map(convert_to_fp32, params)
-
-    @staticmethod
-    def wrap_optimizer_(optimizer):
-        for group in optimizer.param_groups:
-            group['params'] = ConvertToFP32(group['params'])
-
-    @staticmethod
-    def unwrap_optimizer_(optimizer):
-        for group in optimizer.param_groups:
-            group['params'] = group['params'].params  # unwrap from ConvertToFP32
-            for p in group['params']:
-                p.data = p.data.half()
-                if p.grad is not None:
-                    p.grad.data = p.grad.data.half()
-
-    def __len__(self):
-        return len(self.params)
-
-    def __iter__(self):
-        if self.itr is not None:
-            return self
-        else:
-            return iter(self.params)
-
-    def __next__(self):
-        try:
-            return next(self.itr)
-        except StopIteration:
-            self.itr = None
-            raise StopIteration
-
-
class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

-    Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
-    memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
-    is reduced optimization speed, which can be mitigated with `--update-freq`.
+    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
+    maintain an FP32 copy of the model. We instead expect the optimizer to
+    convert the gradients to FP32 internally and sync the results back to the
+    FP16 model params. This significantly reduces memory usage but slightly
+    increases the time spent in the optimizer.
+
+    Since this wrapper depends on specific functionality in the wrapped
+    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
+    optimizers can be wrapped. This is determined by the
+    *supports_memory_efficient_fp16* property.
    """

    def __init__(self, args, params, optimizer):
+        if not optimizer.supports_memory_efficient_fp16:
+            raise ValueError(
+                'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
+            )
+
        super().__init__(args, params)
        self.wrapped_optimizer = optimizer
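To make the docstring's memory claim concrete, here is a rough back-of-the-envelope accounting for an illustrative 200M-parameter model trained with Adam; the FP32 master copy is what this wrapper avoids, while the FP32 optimizer state is needed in either mode (numbers are approximate and ignore activations and the loss-scaling machinery):

    n_params = 200_000_000
    fp16_params = 2 * n_params          # model weights kept in FP16
    fp16_grads = 2 * n_params           # gradients kept in FP16
    fp32_master = 4 * n_params          # FP32 copy kept by FP16Optimizer, avoided here
    adam_state = 2 * 4 * n_params       # exp_avg + exp_avg_sq in FP32, needed either way

    gib = 2 ** 30
    print('saved: %.1f GiB of %.1f GiB' % (
        fp32_master / gib,
        (fp16_params + fp16_grads + fp32_master + adam_state) / gib,
    ))   # roughly 0.7 GiB saved out of 3.0 GiB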
@@ -329,9 +292,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
        """
        if 'loss_scale' in state_dict:
            self.scaler.loss_scale = state_dict['loss_scale']
-        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
-        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
@@ -384,15 +345,8 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
    def step(self, closure=None):
        """Performs a single optimization step."""
        self._unscale_grads()
-
-        # convert params and grads to FP32 (lazily)
-        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
-
        self.wrapped_optimizer.step(closure)

-        # convert params back to FP16
-        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
-
    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
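In the new step(), _unscale_grads() is the only FP16-specific work left in the wrapper: it removes the loss scale from the FP16 gradients before handing off to the wrapped optimizer, which then does its own FP32 conversion. A minimal sketch of that unscaling, assuming a simple scalar loss scale (illustrative, not the fairseq LossScaler/_unscale_grads implementation):

    import torch

    def unscale_grads(params, loss_scale):
        # grads came from (loss * loss_scale).backward(), so divide the scale back out
        for p in params:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / loss_scale)

    p = torch.nn.Parameter(torch.zeros(3).half())
    p.grad = torch.full((3,), 128.0).half()
    unscale_grads([p], loss_scale=128.0)   # grads are now 1.0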
@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

+import torch
from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer
@@ -46,6 +47,10 @@ class NAG(Optimizer):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
        super(NAG, self).__init__(params, defaults)

+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -68,20 +73,26 @@ class NAG(Optimizer):
                if p.grad is None:
                    continue

-                d_p = p.grad.data
+                p_data_fp32 = p.data.float()
+
+                d_p = p.grad.data.float()
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
-                    param_state['momentum_buffer'] = d_p.clone().zero_()
+                    param_state['momentum_buffer'] = torch.zeros_like(d_p)
+                else:
+                    param_state['momentum_buffer'] = param_state['momentum_buffer'].type_as(d_p)

                buf = param_state['momentum_buffer']

                if weight_decay != 0:
-                    p.data.mul_(1 - lr * weight_decay)
-                p.data.add_(momentum * momentum * lr_correct, buf)
-                p.data.add_(-(1 + momentum) * lr, d_p)
+                    p_data_fp32.mul_(1 - lr * weight_decay)
+                p_data_fp32.add_(momentum * momentum * lr_correct, buf)
+                p_data_fp32.add_(-(1 + momentum) * lr, d_p)

                buf.mul_(momentum * lr_correct).add_(-lr, d_p)

+                p.data.copy_(p_data_fp32)
+
            group['lr_old'] = lr

        return loss
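Written out, the update this hunk now computes on the FP32 copy (with learning rate \eta, weight decay \lambda, momentum \mu, correction \gamma = lr / lr_old, gradient g and momentum buffer b) is:

    \theta \leftarrow (1 - \eta \lambda)\,\theta + \mu^{2} \gamma\, b - (1 + \mu)\,\eta\, g
    b \leftarrow \mu \gamma\, b - \eta\, g

after which \theta is copied back into the FP16 parameter with p.data.copy_(p_data_fp32); only the precision handling changes in this commit, not the update rule itself.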