"...git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "4726d5be84307a094f0e557272f9e0530ae84395"
Commit 889d871b authored by Raul Puri, committed via GitHub

Create LARC.py

import torch


class LARC(object):
    """
    Wraps an existing PyTorch optimizer and rescales each parameter's gradient
    by a layer-wise adaptive learning rate (trust ratio) before the wrapped
    optimizer takes its step.
    """

    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = epsilon

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from the wrapped optimizer;
                # it is applied manually below and restored after the step
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    # layer-wise adaptive lr: ||w|| / (||grad|| + weight_decay * ||w||)
                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
                    # apply weight decay to the gradient, then scale it by the trust ratio
                    p.grad.data += weight_decay * p.data
                    p.grad.data *= self.trust_coefficient * adaptive_lr
        self.optim.step()
        # return weight decay control to the wrapped optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
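A minimal usage sketch follows. It is not part of this commit: the model, the SGD hyperparameters, and the module name in the import are illustrative assumptions. The point it shows is that LARC wraps a regular optimizer, so a training loop calls zero_grad(), step(), and state_dict() on the wrapper exactly as it would on the underlying optimizer.

# Usage sketch (illustrative, not from this commit).
import torch
from torch import nn
from LARC import LARC  # assumes this file is importable as LARC.py

model = nn.Linear(10, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                                 momentum=0.9, weight_decay=1e-4)
optimizer = LARC(base_optimizer, trust_coefficient=0.02)

inputs = torch.randn(4, 10)
targets = torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # rescales each layer's gradient, then calls SGD's step()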