Commit 70d2e3b7 authored by Christian Sarofeen, committed by GitHub

Merge pull request #3 from raulpuric/patch-2

Create LARC.py
parents 2d5b71bd 889d871b
import torch


class LARC(object):
    """
    Wraps an existing torch.optim optimizer and applies LARC: before each step,
    every parameter's gradient is rescaled by an adaptive per-parameter rate
    proportional to ||param|| / (||grad|| + weight_decay * ||param||), scaled by
    the trust coefficient. Weight decay is folded into the gradient here, so it
    is temporarily zeroed in the wrapped optimizer during step().

    Args:
        optimizer: the wrapped optimizer (e.g. torch.optim.SGD).
        trust_coefficient: multiplier applied to the adaptive rate.
        epsilon: small constant for numerical stability.
    """

    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = epsilon

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from the wrapped optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    # calculate the adaptive lr, then fold weight decay into the gradient
                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
                    p.grad.data += weight_decay * p.data
                    p.grad.data *= self.trust_coefficient * adaptive_lr

        self.optim.step()

        # return weight decay control to the wrapped optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
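A minimal usage sketch follows. The model, data, and hyperparameters below are illustrative assumptions, not part of this commit; the pattern is simply to construct a standard optimizer, wrap it in LARC, and call the wrapper in place of the optimizer.

import torch
from torch import nn

# hypothetical model and base optimizer, for illustration only
model = nn.Linear(10, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                                 momentum=0.9, weight_decay=1e-4)

# wrap the base optimizer; LARC rescales gradients, then delegates the update to SGD
optimizer = LARC(base_optimizer, trust_coefficient=0.02)

inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

Because the wrapper exposes the underlying param_groups, it can generally be passed wherever the wrapped optimizer would be (for example to a learning-rate scheduler), though that is beyond this sketch.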