Unverified commit ec79b239 authored by Samyam Rajbhandari, committed by GitHub

Add files via upload

Different Optimizers in DeepSpeed.
parent 87c9fe3d
'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
'''
import types
import torch
import importlib
class FusedLamb(torch.optim.Optimizer):
"""Implements LAMB algorithm. Currently GPU-only. Requires DeepSpeed adapted Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
For usage example please see, TODO DeepSpeed Tutorial
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
https://arxiv.org/abs/1904.00962
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
max_coeff (float, optional): maximum value of the LAMB coefficient (default: 10.0)
min_coeff (float, optional): minimum value of the LAMB coefficient (default: 0.01)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedLamb!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
max_grad_norm=0.,
max_coeff=10.0,
min_coeff=0.01,
amsgrad=False):
global fused_lamb_cuda
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
if amsgrad:
raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm,
max_coeff=max_coeff,
min_coeff=min_coeff)
super(FusedLamb, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self.lamb_coeffs = []
def step(self,
closure=None,
grads=None,
output_params=None,
scale=1.,
grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output_params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Must be of the same type as the gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()
if grads is None:
grads_group = [None] * len(self.param_groups)
# backward compatibility
# assuming a list/generator of parameters means a single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0]) != list:
grads_group = [grads]
else:
grads_group = grads
if output_params is None:
output_params_group = [None] * len(self.param_groups)
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0]) != list:
output_params_group = [output_params]
else:
output_params_group = output_params
if grad_norms is None:
grad_norms = [None] * len(self.param_groups)
#remove the previous coeffs
del self.lamb_coeffs[:]
for group, grads_this_group, output_params_this_group, grad_norm_group in zip(self.param_groups, grads_group, output_params_group, grad_norms):
if grads_this_group is None:
grads_this_group = [None] * len(group['params'])
if output_params_this_group is None:
output_params_this_group = [None] * len(group['params'])
if grad_norm_group is None:
grad_norm_group = [None] * len(group['params'])
elif not isinstance(grad_norm_group, list):
grad_norm_group = [grad_norm_group]
bias_correction = 1 if group['bias_correction'] else 0
for p, grad, output_param, grad_norm in zip(group['params'], grads_this_group, output_params_this_group, grad_norm_group):
# compute combined scale factor for this group
combined_scale = scale
if group['max_grad_norm'] > 0:
# norm is in fact norm*scale
clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
if clip > 1:
combined_scale = clip * scale
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None:
continue
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'FusedLamb does not support sparse gradients, please consider SparseAdam instead'
)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
max_coeff = group['max_coeff']
min_coeff = group['min_coeff']
state['step'] += 1
out_p = torch.tensor(
[],
dtype=torch.float) if output_param is None else output_param
lamb_coeff = fused_lamb_cuda.lamb(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
max_coeff,
min_coeff,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
self.lamb_coeffs.append(lamb_coeff)
return loss
def get_lamb_coeffs(self):
lamb_coeffs = [lamb_coeff.item() for lamb_coeff in self.lamb_coeffs]
return lamb_coeffs
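##############################################################
# Illustrative usage sketch -- not exercised by this file
##############################################################
"""
A minimal sketch of driving FusedLamb like any other torch optimizer. It assumes
the fused_lamb_cuda extension has been built (via the DeepSpeed-adapted Apex
install noted in the class docstring) and that a CUDA device is available; the
model and hyperparameters below are placeholders for illustration only.
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = FusedLamb(model.parameters(), lr=1e-3, weight_decay=0.01)
    loss = model(torch.randn(8, 1024, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    # per-parameter LAMB trust ratios computed by the fused kernel in the last step
    print(optimizer.get_lamb_coeffs())
"""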
'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
'''
import torch
import logging
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
import math
class FP16_Optimizer(object):
"""
FP16 Optimizer for training fp16 models. Handles loss scaling.
For a usage example, please see TODO: DeepSpeed V2 Tutorial
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
initial_dynamic_scale=2**32,
dynamic_loss_args=None,
verbose=True,
mpu=None,
clip_grad=0.0,
fused_adam_legacy=False):
self.fused_adam_legacy = fused_adam_legacy
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
# param flattened by groups
self.fp16_groups = []
self.fp16_groups_flat = []
self.fp32_groups_flat = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to the list before we modify it
self.fp16_groups.append(param_group['params'])
# init fp16 weight buffer, flattened
self.fp16_groups_flat.append(
_flatten_dense_tensors([p.clone().detach()
for p in self.fp16_groups[i]]))
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
# init master weight, flattened
self.fp32_groups_flat.append(
self.fp16_groups_flat[i].clone().float().detach())
# modify optimizer to have flat master weight
self.fp32_groups_flat[
i].requires_grad = True # keep this in case internal optimizer uses it
param_group['params'] = [self.fp32_groups_flat[i]]
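# Worked example of the bookkeeping above (illustrative): two fp16 params of
# shapes (2, 3) and (4,) are flattened into one fp16 buffer of 10 elements,
# the model params are re-pointed at views into that buffer, and a detached
# fp32 clone of the buffer becomes the master weight that the wrapped
# optimizer actually updates.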
# we may have a way of fusing dynamic scale. Not supported for now
if dynamic_loss_scale:
if dynamic_loss_args is not None:
logging.warning("Dynamic loss scale args are not supported for now.")
self.dynamic_loss_scale = True
self.cur_scale = initial_dynamic_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2
self.scale_window = 1000
else:
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
self.verbose = verbose
self.clip_grad = clip_grad
self.norm_type = 2
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
else:
self.clip_grad_norm = torch.nn.utils.clip_grad_norm_
#model parallel object
self.mpu = None
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def step_fused_adam(self, closure=None):
"""
Not supporting closure.
"""
# First compute norm for all groups so we know if there is overflow
grads_groups_flat = []
norm_groups = []
for i, group in enumerate(self.fp16_groups):
grads_groups_flat.append(
_flatten_dense_tensors([
torch.zeros(p.size(),
dtype=p.dtype,
device=p.device) if p.grad is None else p.grad
for p in group
]))
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
self.overflow = self.overflow_checker.check_using_norm(norm_groups)
prev_scale = self.cur_scale
if self.overflow:
self._update_scale(self.overflow)
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow
combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
norm_groups,
apply_scale=False)
# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
output_params=[[p] for p in self.fp16_groups_flat],
scale=combined_scale,
grad_norms=norm_groups)
# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data
return self.overflow
def step(self, closure=None):
"""
Not supporting closure.
"""
if self.fused_adam_legacy:
return self.step_fused_adam()
# First compute norm for all groups so we know if there is overflow
grads_groups_flat = []
norm_groups = []
for i, group in enumerate(self.fp16_groups):
data_type = self.fp32_groups_flat[i].dtype
grads_groups_flat.append(
_flatten_dense_tensors([
torch.zeros(p.size(),
dtype=data_type,
device=p.device)
if p.grad is None else p.grad.to(data_type) for p in group
]))
self.fp32_groups_flat[i].grad = grads_groups_flat[i]
norm_groups.append(get_grad_norm(self.fp32_groups_flat, mpu=self.mpu))
self.overflow = self.overflow_checker.check_using_norm(norm_groups)
prev_scale = self.cur_scale
if self.overflow:
self._update_scale(self.overflow)
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow
self.unscale_and_clip_grads(grads_groups_flat, norm_groups)
self.optimizer.step()
#get rid of the fp32 gradients. Not needed anymore
for group in self.fp32_groups_flat:
group.grad = None
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data)
return self.overflow
def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.cur_scale
if apply_scale:
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
return combined_scale
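# Worked example (illustrative): with cur_scale=1024, clip_grad=1.0 and a
# combined scaled-gradient norm total_norm=4096, clip is roughly 4, so
# combined_scale becomes ~4096; dividing the scaled gradients by it both
# removes the loss scale and clips the true gradient norm to about 1.0.
# When total_norm / cur_scale is already below clip_grad, clip <= 1 and only
# the loss scale (here 1/1024) is removed.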
def backward(self, loss):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
scaled_loss = (loss.float()) * self.cur_scale
scaled_loss.backward()
def _update_scale(self, skip):
if self.dynamic_loss_scale:
if skip:
if self.verbose:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using dynamic loss scale of", self.cur_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
else:
if skip:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using static loss scale of", self.cur_scale)
self.cur_iter += 1
return
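# Illustrative trace of the dynamic policy above: starting from
# initial_dynamic_scale=2**32, each overflow halves cur_scale (never below 1)
# and records the iteration; after scale_window=1000 consecutive iterations
# without overflow, cur_scale is doubled again. With a static loss scale the
# overflowed step is simply skipped and the scale stays fixed.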
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
state_dict['clip_grad'] = self.clip_grad
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
self.clip_grad = state_dict['clip_grad']
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
current.data.copy_(saved.data)
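##############################################################
# Illustrative usage sketch -- not exercised by this file
##############################################################
"""
A minimal sketch of the intended training loop, assuming a CUDA device, an
fp16 model, and an Adam-style inner optimizer; model, data_loader and loss_fn
are placeholders.
    model = torch.nn.Linear(1024, 1024).cuda().half()
    base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = FP16_Optimizer(base_optimizer, dynamic_loss_scale=True, clip_grad=1.0)
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        optimizer.backward(loss)  # scales the loss, then backpropagates
        optimizer.step()          # unscale/clip, update fp32 masters, copy back to fp16
"""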
'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
'''
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
import math
import logging
class FP16_UnfusedOptimizer(object):
"""
FP16 optimizer without weight fusion, to support the LAMB optimizer.
For a usage example, please see TODO: DeepSpeed V2 Tutorial
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
mpu=None,
clip_grad=0.0,
fused_lamb_legacy=False):
self.fused_lamb_legacy = fused_lamb_legacy
if torch.distributed.get_rank() == 0:
logging.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
# param groups
self.fp16_groups = []
self.fp32_groups = []
# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# fp16 weights that represent the actual model weights
self.fp16_groups.append(param_group['params'])
# create an fp32 copy of the weights that is updated first and then
# copied back to the fp16 weights
fp32_group = [p.clone().float().detach() for p in param_group['params']]
# in case the internal optimizer needs it
for p in fp32_group:
p.requires_grad = True
# set the param groups in the optimizer to point to fp32;
# note these are not the weights used by the model,
# the model uses the fp16 versions that we added to fp16_groups
self.fp32_groups.append(fp32_group)
param_group['params'] = self.fp32_groups[i]
# we may have a way of fusing dynamic scale. Not supported for now
if dynamic_loss_scale:
if dynamic_loss_args is not None:
raise SystemError("Dynamic loss scale args are not supported for now.")
self.dynamic_loss_scale = True
self.cur_scale = 1.0 * 2**16
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2.0
self.scale_window = 1000
else:
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
self.verbose = verbose
self.clip_grad = clip_grad
self.norm_type = 2
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
else:
self.clip_grad_norm = torch.nn.utils.clip_grad_norm_
self.mpu = None
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist outside of the step function
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def step_fused_lamb(self, closure=None):
"""
Not supporting closure.
"""
# First compute norm for all groups so we know if there is overflow
grads_groups_flat = []
grads_groups = []
norm_groups = []
for i, group in enumerate(self.fp16_groups):
grads_groups.append([p.grad for p in group])
grads_groups_flat.append(_flatten_dense_tensors(grads_groups[i]))
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
self.overflow = self.overflow_checker.check_using_norm(norm_groups)
prev_scale = self.cur_scale
if self.overflow:
self._update_scale(self.overflow)
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow
combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
self.optimizer.step(grads=grads_groups,
output_params=self.fp16_groups,
scale=combined_scale)
return self.overflow
def step(self, closure=None):
"""
Not supporting closure.
"""
if self.fused_lamb_legacy:
return self.step_fused_lamb()
self.overflow = self.overflow_checker.check()
prev_scale = self.cur_scale
if self.overflow:
self._update_scale(self.overflow)
if self.verbose:
print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow
norm_groups = []
for i, group in enumerate(self.fp16_groups):
norm_groups.append(get_grad_norm(group, mpu=self.mpu))
# copying gradients to fp32 to work with fp32 parameters
for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)
self.unscale_and_clip_grads(norm_groups)
self.optimizer.step()
for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
for fp32_param, fp16_param in zip(fp32_group, fp16_group):
#remove the fp32 grad
fp32_param.grad = None
#copy data from fp32 to fp16
fp16_param.data.copy_(fp32_param.data)
return self.overflow
def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
if clip > 1:
combined_scale = clip * self.cur_scale
if apply_scale:
for group in self.fp32_groups:
for param in group:
if param.grad is not None:
param.grad.data.mul_(1. / combined_scale)
return combined_scale
def backward(self, loss):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
scaled_loss = (loss.float()) * self.cur_scale
scaled_loss.backward()
def _update_scale(self, skip):
if self.dynamic_loss_scale:
if skip:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using dynamic loss scale of", self.cur_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, 0.25)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
else:
if skip:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using static loss scale of", self.cur_scale)
self.cur_iter += 1
return
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups'] = self.fp32_groups
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
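##############################################################
# Illustrative usage sketch -- not exercised by this file
##############################################################
"""
A minimal sketch pairing FP16_UnfusedOptimizer with a LAMB-style inner
optimizer such as the FusedLamb added in this change. It assumes
torch.distributed has been initialized (the constructor logs from rank 0),
a CUDA device, and an fp16 model; model, data_loader and loss_fn are
placeholders.
    model = torch.nn.Linear(1024, 1024).cuda().half()
    lamb = FusedLamb(model.parameters(), lr=1e-3, weight_decay=0.01)
    optimizer = FP16_UnfusedOptimizer(lamb,
                                      dynamic_loss_scale=True,
                                      clip_grad=1.0,
                                      fused_lamb_legacy=True)
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        optimizer.backward(loss)
        optimizer.step()
"""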
# Copyright 2019 The Microsoft DeepSpeed Team
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Taken and modified for DeepSpeed from:
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
#Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
import torch
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
Args:
scale (float, optional, default=1.0): The loss scale.
"""
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x):
return False
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether the gradient overflowed
def update_scale(self, overflow):
if not hasattr(self, 'min_scale'):
self.min_scale = 1
if not hasattr(self, 'delayed_shift'):
self.delayed_shift = 1
if not hasattr(self, 'cur_hysteresis'):
self.cur_hysteresis = 1
if not hasattr(self, 'consecutive_hysteresis'):
self.consecutive_hysteresis = True
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
if self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
if not self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
self.cur_scale *= self.scale_factor
self.cur_iter += 1
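# Illustrative trace (assuming the defaults init_scale=2**32, scale_factor=2,
# delayed_shift=1): an overflow at iteration k drops cur_scale to 2**31 and
# records k; after scale_window=1000 overflow-free iterations the scale
# doubles again. With delayed_shift > 1, the first (delayed_shift - 1)
# overflows only consume hysteresis (cur_hysteresis is decremented) and the
# scale is actually reduced once the hysteresis budget is exhausted.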
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = loss_scaler.has_overflow_serial(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
"""