Commit ad50ce9a authored by Kexin Yu's avatar Kexin Yu
Browse files

Add test cases for non-zero weight decay.

parent cd3d6d12
...@@ -190,20 +190,22 @@ class TestFusedLAMB(unittest.TestCase): ...@@ -190,20 +190,22 @@ class TestFusedLAMB(unittest.TestCase):
def gen_single_type_test(self, param_type=torch.float):
    """Compare the fused LAMB against the reference optimizer on one tensor.

    For each weight-decay setting (zero and non-zero), builds a fresh
    reference/test optimizer pair over a single large random tensor of
    dtype ``param_type``, steps both for ``self.iters`` iterations, and
    asserts the parameter divergence stays within the case tolerances
    (``self.max_abs_diff`` / ``self.max_rel_diff``).
    """
    nelem = 278011
    tensor = torch.rand(nelem, dtype=param_type, device='cuda')

    # Exercise both the zero and the non-zero weight-decay code paths.
    for wd in (0, 0.01):
        lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999),
                       'eps': 1e-08, 'weight_decay': wd}
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], lamb_option)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_diff, rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(abs_diff, self.max_abs_diff)
            self.assertLessEqual(rel_diff, self.max_rel_diff)
def test_float(self):
    """Run the single-tensor LAMB parity check in fp32."""
    self.gen_single_type_test(torch.float)
...@@ -214,38 +216,42 @@ class TestFusedLAMB(unittest.TestCase): ...@@ -214,38 +216,42 @@ class TestFusedLAMB(unittest.TestCase):
def test_multi_params(self):
    """Compare fused vs. reference LAMB over parameters of mixed shapes.

    For each weight-decay setting (zero and non-zero), creates one random
    fp32 CUDA tensor per shape in ``sizes`` (2-D matrices, 1-D vectors,
    and a scalar-like ``[1]``), steps both optimizers ``self.iters``
    times, and asserts the max absolute/relative parameter differences
    stay within the case tolerances.
    """
    sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]

    for wd in (0, 0.01):
        lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999),
                       'eps': 1e-08, 'weight_decay': wd}
        # Idiomatic comprehension instead of a manual append loop; a fresh
        # set of tensors per weight-decay run, matching the original.
        tensors = [torch.rand(size, dtype=torch.float, device='cuda')
                   for size in sizes]
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim(tensors, lamb_option)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_lamb_option(self):
    """Parity check with non-default LAMB hyperparameters on a 1-element param.

    Uses lr=0.01, betas=(0.6, 0.9), eps=3e-06 at both zero and non-zero
    weight decay, asserting the fused and reference parameters agree
    within ``self.max_abs_diff`` / ``self.max_rel_diff``.
    """
    tensor = torch.rand(1, dtype=torch.float, device='cuda')

    for wd in (0, 0.01):
        options = {'lr': 0.01, 'betas': (0.6, 0.9),
                   'eps': 3e-06, 'weight_decay': wd}
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], options)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_diff, rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(abs_diff, self.max_abs_diff)
            self.assertLessEqual(rel_diff, self.max_rel_diff)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment