Commit d43ea972 authored by Phil Wang's avatar Phil Wang
Browse files

make sure interface is correct

parent 7247cb45
...@@ -9,30 +9,21 @@ class Lion(Optimizer1State): ...@@ -9,30 +9,21 @@ class Lion(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
optim_bits=32, optim_bits=32,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
optim_bits, optim_bits,
args, args,
...@@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State): ...@@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
8, 8,
args, args,
...@@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State): ...@@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
32, 32,
args, args,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment