Unverified Commit 11faaca7 authored by Mannat Singh's avatar Mannat Singh Committed by GitHub
Browse files

Return internal optimizer's param_groups from LARC (#767)

parent ca00adac
...@@ -37,7 +37,6 @@ class LARC(object): ...@@ -37,7 +37,6 @@ class LARC(object):
""" """
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.param_groups = optimizer.param_groups
self.optim = optimizer self.optim = optimizer
self.trust_coefficient = trust_coefficient self.trust_coefficient = trust_coefficient
self.eps = eps self.eps = eps
...@@ -52,6 +51,14 @@ class LARC(object): ...@@ -52,6 +51,14 @@ class LARC(object):
def __repr__(self): def __repr__(self):
return self.optim.__repr__() return self.optim.__repr__()
@property
def param_groups(self):
return self.optim.param_groups
@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value
def state_dict(self): def state_dict(self):
return self.optim.state_dict() return self.optim.state_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment