"vscode:/vscode.git/clone" did not exist on "194a084640bca80f99af121c2cd64755c31f64f3"
Commit ad50ce9a authored by Kexin Yu's avatar Kexin Yu
Browse files

add test case for non-zero weight decay

parent cd3d6d12
@@ -190,9 +190,11 @@ class TestFusedLAMB(unittest.TestCase):
     def gen_single_type_test(self, param_type=torch.float):
         nelem = 278011
-        lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
         tensor = torch.rand(nelem, dtype=param_type, device='cuda')
+        weight_decay = [0, 0.01]
+        for wd in weight_decay:
+            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
         ref_param, tst_param, ref_optim, tst_optim = \
             self.gen_param_optim([tensor], lamb_option)
@@ -214,8 +216,10 @@ class TestFusedLAMB(unittest.TestCase):
     def test_multi_params(self):
         sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
-        lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
+        weight_decay = [0, 0.01]
+        for wd in weight_decay:
+            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
         tensors = []
         for size in sizes:
             tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
@@ -232,9 +236,11 @@ class TestFusedLAMB(unittest.TestCase):
     def test_lamb_option(self):
         nelem = 1
-        lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':0}
         tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
+        weight_decay = [0, 0.01]
+        for wd in weight_decay:
+            lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
         ref_param, tst_param, ref_optim, tst_optim = \
             self.gen_param_optim([tensor], lamb_option)
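For context, the pattern this commit introduces is to sweep both a zero and a non-zero weight_decay value instead of testing only weight_decay=0. The following is a minimal standalone sketch of that pattern, not the repository's actual test helper (gen_param_optim and the reference-vs-fused comparison are omitted); it assumes apex is installed with its CUDA extensions and that a GPU is available.

# Standalone smoke test: run FusedLAMB once for each weight_decay value
# and check that the parameter actually gets updated. Hypothetical test
# name and loss; not the code from this repository.
import torch
from apex.optimizers import FusedLAMB

def smoke_test_fused_lamb():
    nelem = 278011
    for wd in [0, 0.01]:
        lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
                       'weight_decay': wd}
        param = torch.nn.Parameter(torch.rand(nelem, device='cuda'))
        optim = FusedLAMB([param], **lamb_option)
        before = param.detach().clone()
        # One dummy step: any differentiable loss is enough for a smoke test.
        (param * param).sum().backward()
        optim.step()
        # With weight_decay > 0 the update also shrinks the parameter,
        # but in both cases the parameter must have moved.
        assert not torch.equal(before, param.detach()), \
            f"FusedLAMB did not update parameters (weight_decay={wd})"

if __name__ == '__main__':
    smoke_test_fused_lamb()
    print("FusedLAMB smoke test passed for weight_decay in [0, 0.01]")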