Commit cd3d6d12 authored by Kexin Yu

test nvlamb; hyperparams consistent with adam/adagrad tests

parent 9774ce0d
@@ -144,7 +144,7 @@ class RefLAMB(Optimizer):
 class TestFusedLAMB(unittest.TestCase):
-    def setUp(self, max_abs_diff=1e-2, max_rel_diff=1, iters=7):
+    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
         self.max_abs_diff = max_abs_diff
         self.max_rel_diff = max_rel_diff
         self.iters = iters
@@ -161,7 +161,7 @@ class TestFusedLAMB(unittest.TestCase):
             tst_param.append(torch.nn.Parameter(tensor.clone()))
         ref_optim = RefLAMB(ref_param, **lamb_option)
-        tst_optim = apex.optimizers.FusedLAMB(tst_param, **lamb_option)
+        tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)
         return (ref_param, tst_param, ref_optim, tst_optim)
@@ -190,7 +190,7 @@ class TestFusedLAMB(unittest.TestCase):
     def gen_single_type_test(self, param_type=torch.float):
         nelem = 278011
-        lamb_option = {'lr':1e-3, 'betas':(0.9, 0.999), 'eps':1e-06, 'weight_decay':0.0}
+        lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
         tensor = torch.rand(nelem, dtype=param_type, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
@@ -214,7 +214,7 @@ class TestFusedLAMB(unittest.TestCase):
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
-        lamb_option = {'lr':1e-3, 'betas':(0.9, 0.999), 'eps':1e-06, 'weight_decay':0.0}
+        lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
         tensors = []
         for size in sizes:
@@ -232,7 +232,7 @@ class TestFusedLAMB(unittest.TestCase):
     def test_lamb_option(self):
         nelem = 1
-        lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':0.0}
+        lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':0}
         tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
         ref_param, tst_param, ref_optim, tst_optim = \
...
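For context, this is roughly what the updated fixture builds: the reference optimizer (RefLAMB, the class defined earlier in this same test file) and apex.optimizers.FusedLAMB receive identical hyperparameters, and the fused optimizer additionally runs in NVLAMB mode. A minimal standalone sketch, assuming a CUDA device, that RefLAMB is in scope, and using the tensor size and options from gen_single_type_test above; it is not the test itself, only the pairing pattern it exercises.

import torch
import apex

# Shared hyperparameters, aligned with the Adam/Adagrad tests in this suite.
lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0}

# One flat parameter tensor, cloned so both optimizers start from identical weights.
tensor = torch.rand(278011, dtype=torch.float, device='cuda')
ref_param = [torch.nn.Parameter(tensor.clone())]
tst_param = [torch.nn.Parameter(tensor.clone())]

# Reference implementation (RefLAMB from this test file, an assumption here)
# versus the fused kernel with the NVLAMB variant enabled.
ref_optim = RefLAMB(ref_param, **lamb_option)
tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)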