Unverified Commit 11faaca7 authored by Mannat Singh's avatar Mannat Singh Committed by GitHub
Browse files

Return internal optimizer's param_groups from LARC (#767)

parent ca00adac
......@@ -37,7 +37,6 @@ class LARC(object):
"""
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.param_groups = optimizer.param_groups
self.optim = optimizer
self.trust_coefficient = trust_coefficient
self.eps = eps
......@@ -52,6 +51,14 @@ class LARC(object):
def __repr__(self):
return self.optim.__repr__()
@property
def param_groups(self):
return self.optim.param_groups
@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value
def state_dict(self):
return self.optim.state_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment