"vscode:/vscode.git/clone" did not exist on "f61295d944c7e10c7e1e16f6ca6f0b352d8e3545"
Commit ad50ce9a authored by Kexin Yu's avatar Kexin Yu
Browse files

add test case for non-zero weight decay

parent cd3d6d12
......@@ -190,9 +190,11 @@ class TestFusedLAMB(unittest.TestCase):
def gen_single_type_test(self, param_type=torch.float):
nelem = 278011
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
tensor = torch.rand(nelem, dtype=param_type, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
......@@ -214,8 +216,10 @@ class TestFusedLAMB(unittest.TestCase):
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':0}
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
......@@ -232,9 +236,11 @@ class TestFusedLAMB(unittest.TestCase):
def test_lamb_option(self):
nelem = 1
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':0}
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment