Commit 88effd5d authored by Raul Puri's avatar Raul Puri Committed by mcarilli
Browse files

LARC clipping+documentation (#6)

* Proper implementation of LARC clipping
 * Documentation of LARC class
 * Modification of FP16_Optimizer to absorb optimizer instance that's being wrapped instead of creating new optimizer instance of same class.
parent 34582381
......@@ -184,7 +184,7 @@ class FP16_Optimizer(object):
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
self.optimizer = init_optimizer.__class__(init_optimizer.param_groups)
self.optimizer = optimizer
if dynamic_loss_scale:
self.dynamic_loss_scale = True
......
......@@ -4,11 +4,45 @@ from torch.autograd import Variable
from torch.nn.parameter import Parameter
class LARC(object):
def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
"""
:class:`LARC` is a pytorch implementation of both the scaling and clipping varients of LARC,
in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
local learning rate for each individual parameter. The algorithm is designed to improve
convergence of large batch training.
See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARC(optim)
```
It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARC(optim)
optim = apex.fp16_utils.FP16_Optimizer(optim)
```
Args:
optimizer: Pytorch optimizer to wrap and modify learning rate for.
trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
eps: epsilon kludge to help with numerical stability while calculating adaotive_lr
"""
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.param_groups = optimizer.param_groups
self.optim = optimizer
self.trust_coefficient = trust_coefficient
self.eps = epsilon
self.eps = eps
self.clip = clip
def __getstate__(self):
return self.optim.__getstate__()
......@@ -43,10 +77,20 @@ class LARC(object):
if p.grad is None:
continue
param_norm = torch.norm(p.data)
# calculate adaptive lr + weight decay
adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
p.grad.data += weight_decay * p.data
p.grad.data *= self.trust_coefficient * adaptive_lr
grad_norm = torch.norm(p.grad.data)
if param_norm != 0 and grad_norm != 0:
# calculate adaptive lr + weight decay
adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)
# clip learning rate for LARC
if self.clip:
# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
adaptive_lr = min(adaptive_lr/group['lr'], 1)
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr
self.optim.step()
# return weight decay control to optimizer
for i, group in enumerate(self.optim.param_groups):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment